Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -241,21 +241,11 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
|
|
| 241 |
)
|
| 242 |
# end sampler
|
| 243 |
|
|
|
|
| 244 |
resolution = resolution.split(" ")[-1]
|
| 245 |
w, h = resolution.split("x")
|
| 246 |
w, h = int(w), int(h)
|
| 247 |
|
| 248 |
-
res_cat = (w * h) ** 0.5
|
| 249 |
-
seq_len = res_cat // 16
|
| 250 |
-
|
| 251 |
-
scaling_method = "ntk"
|
| 252 |
-
train_seq_len = 64
|
| 253 |
-
if scaling_method == "ntk":
|
| 254 |
-
scale_factor = seq_len / train_seq_len
|
| 255 |
-
else:
|
| 256 |
-
raise NotImplementedError
|
| 257 |
-
|
| 258 |
-
print(f"> scale factor: {scale_factor}")
|
| 259 |
|
| 260 |
latent_w, latent_h = w // 8, h // 8
|
| 261 |
if int(seed) != 0:
|
|
@@ -284,9 +274,18 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
|
|
| 284 |
cap_feats=cap_feats,
|
| 285 |
cap_mask=cap_mask,
|
| 286 |
cfg_scale=cfg_scale,
|
| 287 |
-
scale_factor=scale_factor,
|
| 288 |
)
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
print("> start sample")
|
| 291 |
samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
|
| 292 |
samples = samples[:1]
|
|
@@ -511,9 +510,9 @@ def main():
|
|
| 511 |
)
|
| 512 |
with gr.Row():
|
| 513 |
scale_methods = gr.Dropdown(
|
| 514 |
-
value="
|
| 515 |
-
choices=["
|
| 516 |
-
label="
|
| 517 |
)
|
| 518 |
proportional_attn = gr.Checkbox(
|
| 519 |
value=True,
|
|
|
|
| 241 |
)
|
| 242 |
# end sampler
|
| 243 |
|
| 244 |
+
do_extrapolation = "Extrapolation" in resolution
|
| 245 |
resolution = resolution.split(" ")[-1]
|
| 246 |
w, h = resolution.split("x")
|
| 247 |
w, h = int(w), int(h)
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
latent_w, latent_h = w // 8, h // 8
|
| 251 |
if int(seed) != 0:
|
|
|
|
| 274 |
cap_feats=cap_feats,
|
| 275 |
cap_mask=cap_mask,
|
| 276 |
cfg_scale=cfg_scale,
|
|
|
|
| 277 |
)
|
| 278 |
|
| 279 |
+
if proportional_attn:
|
| 280 |
+
model_kwargs["proportional_attn"] = True
|
| 281 |
+
model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
|
| 282 |
+
if do_extrapolation and scaling_method == "Time-aware":
|
| 283 |
+
model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size ** 2)
|
| 284 |
+
else:
|
| 285 |
+
model_kwargs["scale_factor"] = 1.0
|
| 286 |
+
|
| 287 |
+
print(f"> scale factor: {model_kwargs["scale_factor"]}")
|
| 288 |
+
|
| 289 |
print("> start sample")
|
| 290 |
samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
|
| 291 |
samples = samples[:1]
|
|
|
|
| 510 |
)
|
| 511 |
with gr.Row():
|
| 512 |
scale_methods = gr.Dropdown(
|
| 513 |
+
value="Time-aware",
|
| 514 |
+
choices=["Time-aware", "None"],
|
| 515 |
+
label="Rope scaling method",
|
| 516 |
)
|
| 517 |
proportional_attn = gr.Checkbox(
|
| 518 |
value=True,
|