Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,6 +1,9 @@
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-
+
+from huggingface_hub import snapshot_download
+os.makedirs("/home/user/app/checkpoints", exist_ok=True)
+snapshot_download(repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints")
 
 import argparse
 import builtins
@@ -151,8 +154,6 @@ def load_model(args, master_port, rank):
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
         print("Loading ema model.")
-
-    subprocess.run("huggingface-cli download --resume-download Alpha-VLLM/Lumina-Next-T2I --local-dir ./checkpoints --local-dir-use-symlinks False", shell=True)
     ckpt = torch.load(
         os.path.join(
             args.ckpt,
@@ -166,13 +167,15 @@ def load_model(args, master_port, rank):
     return text_encoder, tokenizer, vae, model
 
 
+@spaces.GPU(duration=80)
 @torch.no_grad()
 def model_main(args, master_port, rank, request_queue, response_queue, text_encoder, tokenizer, vae, model):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-
+    text_encoder, tokenizer, vae, model = load_model(args, master_port, rank)
+
     with torch.autocast("cuda", dtype):
         # barrier.wait()
         while True:
@@ -407,7 +410,6 @@ def find_free_port() -> int:
     return port
 
 
-@spaces.GPU
 def main():
     parser = argparse.ArgumentParser()
     mode = "ODE"
@@ -439,7 +441,6 @@ def main():
     # mp_barrier = mp.Barrier(args.num_gpus + 1)
     # barrier = Barrier(args.num_gpus + 1)
     for i in range(args.num_gpus):
-        text_encoder, tokenizer, vae, model = load_model(args, master_port, i)
         request_queues.append(Queue())
         generation_kwargs = dict(
             args=args,
@@ -447,10 +448,6 @@ def main():
             rank=i,
             request_queue=request_queues[i],
             response_queue=response_queue if i == 0 else None,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            vae=vae,
-            model=model
         )
         model_main(**generation_kwargs)
         # thread = Thread(target=model_main, kwargs=generation_kwargs)