Spaces:
Running
Running
Commit
·
cf60dfb
1
Parent(s):
99dc3ef
initial commit
Browse files
- app.py +45 -31
- models/samplers/riemannian_flow_sampler.py +3 -2
- pipe.py +22 -18
app.py
CHANGED
|
@@ -52,11 +52,32 @@ def predict_location(image, model_name, cfg=0.0, num_samples=256):
|
|
| 52 |
|
| 53 |
pipe = PIPES[model_name]
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
# Get single high-confidence prediction
|
|
|
|
| 59 |
high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
return {
|
| 61 |
"lat": predicted_gps[:, 0].astype(float).tolist(),
|
| 62 |
"lon": predicted_gps[:, 1].astype(float).tolist(),
|
|
@@ -234,14 +255,13 @@ def main():
|
|
| 234 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 235 |
|
| 236 |
if st.button("π Predict Location", key="predict_upload"):
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
st.session_state["predictions"] = predictions
|
| 245 |
|
| 246 |
with tab2:
|
| 247 |
url = st.text_input("Enter image URL:", key="image_url")
|
|
@@ -261,16 +281,13 @@ def main():
|
|
| 261 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 262 |
|
| 263 |
if st.button("π Predict Location", key="predict_url"):
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
num_samples=num_samples,
|
| 272 |
-
)
|
| 273 |
-
st.session_state["predictions"] = predictions
|
| 274 |
|
| 275 |
with tab3:
|
| 276 |
examples = load_example_images()
|
|
@@ -290,17 +307,14 @@ def main():
|
|
| 290 |
help=f"Click to predict location for {name}",
|
| 291 |
use_container_width=True,
|
| 292 |
):
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
)
|
| 302 |
-
st.session_state["predictions"] = predictions
|
| 303 |
-
st.rerun()
|
| 304 |
|
| 305 |
st.image(display_image, caption=name, use_container_width=True)
|
| 306 |
st.markdown("</div>", unsafe_allow_html=True)
|
|
|
|
| 52 |
|
| 53 |
pipe = PIPES[model_name]
|
| 54 |
|
| 55 |
+
# Create a progress bar
|
| 56 |
+
progress_bar = st.progress(0)
|
| 57 |
+
status_text = st.empty()
|
| 58 |
+
|
| 59 |
+
def update_progress(step, total_steps):
|
| 60 |
+
progress = float(step) / float(total_steps)
|
| 61 |
+
progress_bar.progress(progress)
|
| 62 |
+
status_text.text(f"Sampling step {step + 1}/{total_steps}")
|
| 63 |
+
|
| 64 |
+
# Get regular predictions with progress updates
|
| 65 |
+
predicted_gps = pipe(
|
| 66 |
+
img,
|
| 67 |
+
batch_size=num_samples,
|
| 68 |
+
cfg=cfg,
|
| 69 |
+
num_steps=16,
|
| 70 |
+
callback=update_progress
|
| 71 |
+
)
|
| 72 |
|
| 73 |
# Get single high-confidence prediction
|
| 74 |
+
status_text.text("Generating high-confidence prediction...")
|
| 75 |
high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
|
| 76 |
+
|
| 77 |
+
# Clear the status text and progress bar
|
| 78 |
+
status_text.empty()
|
| 79 |
+
progress_bar.empty()
|
| 80 |
+
|
| 81 |
return {
|
| 82 |
"lat": predicted_gps[:, 0].astype(float).tolist(),
|
| 83 |
"lon": predicted_gps[:, 1].astype(float).tolist(),
|
|
|
|
| 255 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 256 |
|
| 257 |
if st.button("π Predict Location", key="predict_upload"):
|
| 258 |
+
predictions = predict_location(
|
| 259 |
+
original_image,
|
| 260 |
+
model_name=model_name,
|
| 261 |
+
cfg=cfg_value,
|
| 262 |
+
num_samples=num_samples,
|
| 263 |
+
)
|
| 264 |
+
st.session_state["predictions"] = predictions
|
|
|
|
| 265 |
|
| 266 |
with tab2:
|
| 267 |
url = st.text_input("Enter image URL:", key="image_url")
|
|
|
|
| 281 |
st.markdown("</div>", unsafe_allow_html=True)
|
| 282 |
|
| 283 |
if st.button("π Predict Location", key="predict_url"):
|
| 284 |
+
predictions = predict_location(
|
| 285 |
+
image,
|
| 286 |
+
model_name=model_name,
|
| 287 |
+
cfg=cfg_value,
|
| 288 |
+
num_samples=num_samples,
|
| 289 |
+
)
|
| 290 |
+
st.session_state["predictions"] = predictions
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
with tab3:
|
| 293 |
examples = load_example_images()
|
|
|
|
| 307 |
help=f"Click to predict location for {name}",
|
| 308 |
use_container_width=True,
|
| 309 |
):
|
| 310 |
+
predictions = predict_location(
|
| 311 |
+
original_image,
|
| 312 |
+
model_name=model_name,
|
| 313 |
+
cfg=cfg_value,
|
| 314 |
+
num_samples=num_samples,
|
| 315 |
+
)
|
| 316 |
+
st.session_state["predictions"] = predictions
|
| 317 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
st.image(display_image, caption=name, use_container_width=True)
|
| 320 |
st.markdown("</div>", unsafe_allow_html=True)
|
models/samplers/riemannian_flow_sampler.py
CHANGED
|
@@ -13,6 +13,7 @@ def riemannian_flow_sampler(
|
|
| 13 |
cfg_rate=0,
|
| 14 |
generator=None,
|
| 15 |
return_trajectories=False,
|
|
|
|
| 16 |
):
|
| 17 |
if scheduler is None:
|
| 18 |
raise ValueError("Scheduler must be provided")
|
|
@@ -35,13 +36,13 @@ def riemannian_flow_sampler(
|
|
| 35 |
if cfg_rate > 0 and conditioning_keys is not None:
|
| 36 |
stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
|
| 37 |
stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
|
| 38 |
-
denoised_all = net(stacked_batch)
|
| 39 |
denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
|
| 40 |
denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
|
| 41 |
else:
|
| 42 |
batch["y"] = x_cur
|
| 43 |
batch["gamma"] = gamma_now.expand(x_cur.shape[0])
|
| 44 |
-
denoised = net(batch)
|
| 45 |
|
| 46 |
dt = gamma_next - gamma_now
|
| 47 |
x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised)
|
|
|
|
| 13 |
cfg_rate=0,
|
| 14 |
generator=None,
|
| 15 |
return_trajectories=False,
|
| 16 |
+
callback=None,
|
| 17 |
):
|
| 18 |
if scheduler is None:
|
| 19 |
raise ValueError("Scheduler must be provided")
|
|
|
|
| 36 |
if cfg_rate > 0 and conditioning_keys is not None:
|
| 37 |
stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
|
| 38 |
stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
|
| 39 |
+
denoised_all = net(stacked_batch, current_step=step)
|
| 40 |
denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
|
| 41 |
denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
|
| 42 |
else:
|
| 43 |
batch["y"] = x_cur
|
| 44 |
batch["gamma"] = gamma_now.expand(x_cur.shape[0])
|
| 45 |
+
denoised = net(batch, current_step=step)
|
| 46 |
|
| 47 |
dt = gamma_next - gamma_now
|
| 48 |
x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised)
|
pipe.py
CHANGED
|
@@ -216,6 +216,7 @@ class PlonkPipeline:
|
|
| 216 |
scheduler=None,
|
| 217 |
cfg=0,
|
| 218 |
generator=None,
|
|
|
|
| 219 |
):
|
| 220 |
"""Sample from the model given conditioning.
|
| 221 |
|
|
@@ -228,6 +229,7 @@ class PlonkPipeline:
|
|
| 228 |
scheduler: Custom scheduler function (uses default if not provided)
|
| 229 |
cfg: Classifier-free guidance scale (default 0)
|
| 230 |
generator: Random number generator
|
|
|
|
| 231 |
|
| 232 |
Returns:
|
| 233 |
Sampled GPS coordinates after postprocessing
|
|
@@ -264,26 +266,28 @@ class PlonkPipeline:
|
|
| 264 |
sampler = self.sampler
|
| 265 |
if scheduler is None:
|
| 266 |
scheduler = self.scheduler
|
|
|
|
| 267 |
# Sample from model
|
| 268 |
if num_steps is None:
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
| 287 |
|
| 288 |
# Apply postprocessing and return
|
| 289 |
output = self.postprocessing(output)
|
|
|
|
| 216 |
scheduler=None,
|
| 217 |
cfg=0,
|
| 218 |
generator=None,
|
| 219 |
+
callback=None,
|
| 220 |
):
|
| 221 |
"""Sample from the model given conditioning.
|
| 222 |
|
|
|
|
| 229 |
scheduler: Custom scheduler function (uses default if not provided)
|
| 230 |
cfg: Classifier-free guidance scale (default 0)
|
| 231 |
generator: Random number generator
|
| 232 |
+
callback: Optional callback function to report progress (step, total_steps)
|
| 233 |
|
| 234 |
Returns:
|
| 235 |
Sampled GPS coordinates after postprocessing
|
|
|
|
| 266 |
sampler = self.sampler
|
| 267 |
if scheduler is None:
|
| 268 |
scheduler = self.scheduler
|
| 269 |
+
|
| 270 |
# Sample from model
|
| 271 |
if num_steps is None:
|
| 272 |
+
num_steps = 16 # Default number of steps
|
| 273 |
+
|
| 274 |
+
# Create a wrapper for the model that updates progress
|
| 275 |
+
def model_with_progress(*args, **kwargs):
|
| 276 |
+
step = kwargs.pop('current_step', 0)
|
| 277 |
+
if callback:
|
| 278 |
+
callback(step, num_steps)
|
| 279 |
+
return self.model(*args, **kwargs)
|
| 280 |
+
|
| 281 |
+
output = sampler(
|
| 282 |
+
model_with_progress,
|
| 283 |
+
batch,
|
| 284 |
+
conditioning_keys="emb",
|
| 285 |
+
scheduler=scheduler,
|
| 286 |
+
num_steps=num_steps,
|
| 287 |
+
cfg_rate=cfg,
|
| 288 |
+
generator=generator,
|
| 289 |
+
callback=callback,
|
| 290 |
+
)
|
| 291 |
|
| 292 |
# Apply postprocessing and return
|
| 293 |
output = self.postprocessing(output)
|