Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -117,13 +117,14 @@ def load_models(args, master_port, rank):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     print(f"Creating lm: Gemma-2B")
     text_encoder = (
         AutoModelForCausalLM.from_pretrained(
             "google/gemma-2b",
             torch_dtype=dtype,
-            device_map=
+            device_map=device,
             # device_map="cuda",
             token=hf_token,
         )
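Note: a minimal sketch of the device-fallback loading pattern this hunk introduces, assuming transformers is installed (with accelerate, which device_map requires) and the gated google/gemma-2b checkpoint is accessible; the function name and defaults are illustrative, not taken from app.py:

import torch
from transformers import AutoModelForCausalLM

def load_text_encoder(dtype=torch.bfloat16, hf_token=None):
    # Fall back to CPU when no GPU is visible, as the diff does.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    text_encoder = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        torch_dtype=dtype,
        device_map=device,  # place the whole model on the chosen device
        token=hf_token,
    )
    return text_encoder.eval()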
@@ -146,7 +147,7 @@ def load_models(args, master_port, rank):
     vae = AutoencoderKL.from_pretrained(
         "stabilityai/sdxl-vae",
         torch_dtype=torch.float32,
-    )
+    ).to(device)
 
     print(f"Creating DiT: Next-DiT")
     # latent_size = train_args.image_size // 8
@@ -155,7 +156,7 @@ def load_models(args, master_port, rank):
         cap_feat_dim=cap_feat_dim,
     )
     # model.eval().to("cuda", dtype=dtype)
-    model.eval()
+    model.eval().to(device, dtype=dtype)
 
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
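Note: the two hunks above move device placement into load_models itself: the SDXL VAE is loaded in float32 (as in the diff) and moved with .to(device), while the DiT is put in eval mode and cast to the working dtype in one chained call. A minimal sketch of that placement, assuming diffusers is installed; the Linear module is only a stand-in for the real Next-DiT:

import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# VAE loaded in float32, then moved to the chosen device.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sdxl-vae",
    torch_dtype=torch.float32,
).to(device)

# Placeholder for the Next-DiT model: eval mode plus device/dtype move in one call.
model = torch.nn.Linear(8, 8)
model.eval().to(device, dtype=dtype)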
@@ -169,7 +170,6 @@ def load_models(args, master_port, rank):
         )
         model.load_state_dict(ckpt, strict=True)
 
-    # barrier.wait()
     return text_encoder, tokenizer, vae, model
 
 
@@ -181,12 +181,13 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
 
     print(args)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.cuda.set_device(0)
-
+
     # loading model to gpu
-    text_encoder = text_encoder.cuda()
-    vae = vae.cuda()
-    model = model.to("cuda", dtype=dtype)
+    # text_encoder = text_encoder.cuda()
+    # vae = vae.cuda()
+    # model = model.to("cuda", dtype=dtype)
 
     with torch.autocast("cuda", dtype):
         (
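Note: with placement now handled in load_models, the explicit .cuda() moves inside infer_ode are commented out and inference runs under the autocast context shown in the surrounding lines. A small, self-contained sketch of that autocast usage, with a placeholder module and shapes rather than the real pipeline:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
module = torch.nn.Linear(16, 16).to(device)
x = torch.randn(1, 16, device=device)

# Autocast only matters on CUDA here; keep it disabled on the CPU fallback.
with torch.autocast(device_type=device, dtype=dtype, enabled=(device == "cuda")):
    y = module(x)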
@@ -581,7 +582,7 @@ def main():
            examples_per_page=22,
        )
 
-        @spaces.GPU(duration=
+        @spaces.GPU(duration=200)
        def on_submit(*infer_args):
            result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
            if isinstance(result, ModelFailure):
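Note: this hunk is the ZeroGPU hook matching the "Running on Zero" banner above. On such Spaces the GPU-bound handler is wrapped with spaces.GPU so a GPU is attached only while the call runs, and duration (in seconds) raises the per-call allocation window for longer generations. A minimal sketch of the pattern; the handler body is a placeholder for the real call into infer_ode:

import spaces

@spaces.GPU(duration=200)  # request the GPU for up to 200 s per call, as in the diff
def on_submit(*infer_args):
    # Placeholder: the real handler forwards infer_args to infer_ode(...)
    ...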