Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -7,9 +7,8 @@ subprocess.run(
     shell=True,
 )
 
-from huggingface_hub import snapshot_download
-
 os.makedirs("/home/user/app/checkpoints", exist_ok=True)
+from huggingface_hub import snapshot_download
 snapshot_download(
     repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints"
 )
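This hunk only moves the snapshot_download import below the os.makedirs call; the startup behaviour is unchanged: the model repo is mirrored into the Space's checkpoints directory before the UI starts. A minimal sketch of that startup step, with an illustrative local directory:

import os
from huggingface_hub import snapshot_download

ckpt_dir = "./checkpoints"  # illustrative path; the Space uses /home/user/app/checkpoints
os.makedirs(ckpt_dir, exist_ok=True)

# Download (or reuse the cached copy of) every file in the model repo.
snapshot_download(
    repo_id="Alpha-VLLM/Lumina-Next-T2I",
    local_dir=ckpt_dir,
)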
@@ -32,8 +31,7 @@ import torch.distributed as dist
 from torchvision.transforms.functional import to_pil_image
 
 from PIL import Image
-from
-from threading import Thread, Barrier
+from safetensors.torch import load_file
 
 import models
 
@@ -50,7 +48,6 @@ description = """
 #### Demo current model: `Lumina-Next-T2I`
 
 """
-
 hf_token = os.environ["HF_TOKEN"]
 
 
@@ -161,12 +158,11 @@ def load_models(args, master_port, rank):
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
         print("Loading ema model.")
-    ckpt =
+    ckpt = load_file(
         os.path.join(
             args.ckpt,
-            f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.
+            f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
         ),
-        map_location="cpu",
     )
     model.load_state_dict(ckpt, strict=True)
 
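The EMA checkpoint is now read with safetensors.torch.load_file from a .safetensors shard instead of a pickled checkpoint loaded with map_location="cpu". A minimal round-trip sketch of that pattern, using a hypothetical single-file checkpoint rather than the app's rank-sharded filename:

import torch
from safetensors.torch import load_file, save_file

# Hypothetical state dict saved in safetensors format.
state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
save_file(state, "checkpoint.safetensors")

# load_file returns a plain dict of tensors; device="cpu" keeps them on the CPU,
# playing the role of the old map_location="cpu".
ckpt = load_file("checkpoint.safetensors", device="cpu")
print(sorted(ckpt.keys()))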
@@ -179,17 +175,15 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-
-    print(args)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     torch.cuda.set_device(0)
-
+
     # loading model to gpu
     # text_encoder = text_encoder.cuda()
     # vae = vae.cuda()
     # model = model.to("cuda", dtype=dtype)
 
-    with torch.autocast(
+    with torch.autocast(device, dtype):
         (
             cap,
             resolution,
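The device is now resolved once via torch.device(...) and the inference body runs under torch.autocast(device, dtype). A small self-contained sketch of the same idea; note that torch.autocast documents its first argument as a device-type string ("cuda" or "cpu"), so the string form is used here, and the bfloat16 dtype is an assumption:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16  # assumed precision

x = torch.randn(8, 16, device=device)
layer = torch.nn.Linear(16, 16).to(device)

# autocast expects the device *type* string ("cuda"/"cpu"), not the torch.device object.
with torch.autocast(device.type, dtype=dtype):
    y = layer(x)
print(y.dtype)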
@@ -202,18 +196,19 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
             proportional_attn,
         ) = infer_args
 
-
-
-
-
-
-
-
-
-
-
-            proportional_attn,
+        metadata = dict(
+            cap=cap,
+            resolution=resolution,
+            num_sampling_steps=num_sampling_steps,
+            cfg_scale=cfg_scale,
+            solver=solver,
+            t_shift=t_shift,
+            seed=seed,
+            ntk_scaling=ntk_scaling,
+            proportional_attn=proportional_attn,
         )
+        print("> params:", json.dumps(metadata, indent=2))
+
         try:
             # begin sampler
             transport = create_transport(
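The scattered debug prints are replaced by a single metadata dict that is logged once and later returned alongside the image; the json.dumps call assumes json is imported in app.py. A short sketch of the pattern with hypothetical parameter values:

import json

# Hypothetical generation parameters gathered into one dict.
metadata = dict(
    cap="a photo of a cat",
    resolution="1024x1024",
    num_sampling_steps=30,
    cfg_scale=4.0,
    solver="euler",
    t_shift=6,
    seed=0,
    ntk_scaling=True,
    proportional_attn=True,
)

# One structured log line instead of many ad-hoc prints; the same dict can
# feed a gr.JSON component in the UI.
print("> params:", json.dumps(metadata, indent=2))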
@@ -249,7 +244,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
             latent_w, latent_h = w // 8, h // 8
             if int(seed) != 0:
                 torch.random.manual_seed(int(seed))
-            z = torch.randn([1, 4, latent_h, latent_w], device=
+            z = torch.randn([1, 4, latent_h, latent_w], device=device).to(dtype)
             z = z.repeat(2, 1, 1, 1)
 
             with torch.no_grad():
@@ -276,13 +271,8 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
                 ntk_factor=ntk_factor,
             )
 
-            print(
-
-            print(f"cfg_scale: {cfg_scale}")
-
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                print("> [debug] start sample")
-                samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
+            print("> start sample")
+            samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
             samples = samples[:1]
 
             factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
@@ -294,7 +284,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
 
             img = to_pil_image(samples[0].float())
 
-            return img
+            return img, metadata
     except Exception:
         print(traceback.format_exc())
         return ModelFailure()
@@ -505,18 +495,15 @@ def main():
         )
         with gr.Row():
             submit_btn = gr.Button("Submit", variant="primary")
-            # reset_btn = gr.ClearButton([
-            #     cap, resolution,
-            #     num_sampling_steps, cfg_scale, solver,
-            #     t_shift, seed,
-            #     ntk_scaling, proportional_attn
-            # ])
         with gr.Column():
             output_img = gr.Image(
                 label="Lumina Generated image",
                 interactive=False,
                 format="png",
+                show_label=False
             )
+            with gr.Accordion(label="Generation Parameters", open=False):
+                gr_metadata = gr.JSON(label="metadata", show_label=False)
 
         with gr.Row():
             gr.Examples(
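The output column gains a collapsed "Generation Parameters" accordion holding a gr.JSON viewer, which is later wired up as a second output of the submit handler. A minimal, self-contained Gradio sketch of that layout (the handler and values here are illustrative, not the app's):

import gradio as gr

def fake_generate(prompt):
    # Stand-in for the real sampler: returns no image, just the parameters used.
    return None, {"cap": prompt, "seed": 0}

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    submit_btn = gr.Button("Submit", variant="primary")
    with gr.Column():
        output_img = gr.Image(interactive=False, format="png", show_label=False)
        with gr.Accordion(label="Generation Parameters", open=False):
            gr_metadata = gr.JSON(show_label=False)
    # Two outputs: the image and the metadata shown in the accordion.
    submit_btn.click(fake_generate, [prompt], [output_img, gr_metadata])

if __name__ == "__main__":
    demo.queue().launch()

Calling demo.queue() before launch() enables Gradio's request queue, matching the demo.queue().launch() line in the final hunk of this commit.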
@@ -582,8 +569,8 @@ def main():
                 examples_per_page=22,
             )
 
-        @spaces.GPU(duration=
-        def on_submit(*infer_args):
+        @spaces.GPU(duration=80)
+        def on_submit(*infer_args, progress=gr.Progress(track_tqdm=True),):
             result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
                 raise RuntimeError("Model failed to generate the image.")
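On a ZeroGPU Space, the function that actually needs CUDA is wrapped in @spaces.GPU; duration=80 requests a longer GPU allocation window (in seconds) for each call, and the added progress=gr.Progress(track_tqdm=True) parameter lets Gradio surface tqdm progress from the sampler. A hedged sketch of that decorator usage; the body is a stand-in workload, not the app's sampler:

import gradio as gr
import spaces  # provided inside Hugging Face ZeroGPU Spaces
import torch

@spaces.GPU(duration=80)  # a GPU is attached only while this function runs
def on_submit(prompt, progress=gr.Progress(track_tqdm=True)):
    # Stand-in workload; the real handler calls infer_ode(...) here.
    x = torch.randn(1024, 1024, device="cuda")
    return float((x @ x).mean())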
@@ -602,10 +589,10 @@ def main():
                 ntk_scaling,
                 proportional_attn,
             ],
-            [output_img],
+            [output_img, gr_metadata],
         )
 
-    demo.queue(
+    demo.queue().launch()
 
 
 if __name__ == "__main__":