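"""Streamlit demo for PLONK ("Around the World in 80 Timesteps: A Generative
Approach to Global Visual Geolocation").

Launch locally with `streamlit run <path to this file>`; the `pipe` module
providing `PlonkPipeline`, plus `streamlit-extras` and `plotly`, must be
installed.
"""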
import streamlit as st
import pandas as pd
from PIL import Image
import torch
from pipe import PlonkPipeline
from pathlib import Path
from streamlit_extras.colored_header import colored_header
import plotly.express as px
import requests
from io import BytesIO

# Set page config
st.set_page_config(
    page_title="Around the World in 80 Timesteps", page_icon="🗺️", layout="wide"
)

device = "cuda" if torch.cuda.is_available() else "cpu"

PROJECT_ROOT = Path(__file__).parent.parent.absolute()

# Define checkpoint path
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"

MODEL_NAMES = {
    "PLONK_OSV_5M": "nicolas-dufour/PLONK_OSV_5M",
    "PLONK_YFCC": "nicolas-dufour/PLONK_YFCC",
    "PLONK_iNaturalist": "nicolas-dufour/PLONK_iNaturalist",
}
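# MODEL_NAMES maps UI display names to Hugging Face Hub repositories for the
# PLONK variants (trained on OSV-5M, YFCC, and iNaturalist respectively).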

# Caching keeps each pipeline in memory across Streamlit reruns instead of
# re-instantiating it on every interaction.
@st.cache_resource
def load_model(model_name):
    """Load the model and cache it to prevent reloading"""
    try:
        pipe = PlonkPipeline(model_path=model_name)
        return pipe
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()


PIPES = {model_name: load_model(MODEL_NAMES[model_name]) for model_name in MODEL_NAMES}
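# PIPES keeps one pre-loaded pipeline per variant, so the model selectbox in
# the UI can switch between them mid-session without reloading weights.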

def predict_location(image, model_name, cfg=0.0, num_samples=256):
    """Sample `num_samples` candidate GPS coordinates for an image and also
    return a single high-confidence (cfg=2.0) best guess."""
    with torch.no_grad():
        batch = {"img": [], "emb": []}
        # If image is already a PIL Image, use it directly
        if isinstance(image, Image.Image):
            img = image.convert("RGB")
        else:
            img = Image.open(image).convert("RGB")

        pipe = PIPES[model_name]

        # Create a progress bar
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(step, total_steps):
            progress = float(step) / float(total_steps)
            progress_bar.progress(progress)
            status_text.text(f"Sampling step {step + 1}/{total_steps}")

        # Get regular predictions with progress updates
        predicted_gps = pipe(
            img,
            batch_size=num_samples,
            cfg=cfg,
            num_steps=16,
            callback=update_progress,
        )

        # Get single high-confidence prediction
        status_text.text("Generating high-confidence prediction...")
        high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)

        # Clear the status text and progress bar
        status_text.empty()
        progress_bar.empty()

        return {
            "lat": predicted_gps[:, 0].astype(float).tolist(),
            "lon": predicted_gps[:, 1].astype(float).tolist(),
            "high_conf_lat": high_conf_gps[0, 0].astype(float),
            "high_conf_lon": high_conf_gps[0, 1].astype(float),
        }

def load_example_images():
    """Load example images from the examples directory"""
    examples_dir = Path(__file__).parent / "examples"
    if not examples_dir.exists():
        st.error(
            """
            Examples directory not found. Please create the following structure:
            demo/
            └── examples/
                ├── eiffel_tower.jpg
                ├── colosseum.jpg
                ├── taj_mahal.jpg
                ├── statue_liberty.jpg
                └── sydney_opera.jpg
            """
        )
        return {}

    examples = {}
    for img_path in examples_dir.glob("*.jpg"):
        # Use filename without extension as the key
        name = img_path.stem.replace("_", " ").title()
        examples[name] = str(img_path)

    if not examples:
        st.warning("No example images found in the examples directory.")

    return examples

def resize_image_for_display(image, max_size=400):
    """Resize image while maintaining aspect ratio"""
    # Get current size
    width, height = image.size

    # Calculate ratio to maintain aspect ratio
    if width > height:
        if width > max_size:
            ratio = max_size / width
            new_size = (max_size, int(height * ratio))
    else:
        if height > max_size:
            ratio = max_size / height
            new_size = (int(width * ratio), max_size)

    # Only resize if image is larger than max_size
    if width > max_size or height > max_size:
        return image.resize(new_size, Image.Resampling.LANCZOS)
    return image

def load_image_from_url(url):
    """Load an image from a URL"""
    try:
        # A timeout prevents the app from hanging on unresponsive hosts
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # Raise an exception for bad status codes
        return Image.open(BytesIO(response.content))
    except Exception as e:
        st.error(f"Error loading image from URL: {str(e)}")
        return None

def main():
    # Custom CSS
    st.markdown(
        """
        <style>
        .main {
            padding: 0rem 1rem;
        }
        .stButton>button {
            width: 100%;
            background-color: #FF4B4B;
            color: white;
            border: none;
            padding: 0.5rem 1rem;
            border-radius: 0.5rem;
        }
        .stButton>button:hover {
            background-color: #FF6B6B;
        }
        .prediction-box {
            background-color: #f0f2f6;
            padding: 1.5rem;
            border-radius: 0.5rem;
            margin: 1rem 0;
        }
        /* New styles for image containers */
        .upload-container {
            max-height: 300px;
            overflow-y: auto;
            margin-bottom: 1rem;
        }
        .examples-container {
            max-height: 200px;
            display: flex;
            gap: 10px;
        }
        .stTabs [data-baseweb="tab-panel"] {
            padding-top: 1rem;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    # Header with custom styling
    colored_header(
        label="🗺️ Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation",
        description="Upload an image and our model, PLONK, will predict possible locations! The red point is a single sample drawn with guidance scale 2.0, used as the best guess. Project page: https://nicolas-dufour.github.io/plonk",
        color_name="red-70",
    )

    # Adjust column ratio to give 2/3 of the space to the map
    col1, col2 = st.columns([1, 2], gap="large")

    with col1:
        # Add model selection before the sliders
        model_name = st.selectbox(
            "🤖 Select Model",
            options=MODEL_NAMES.keys(),
            index=0,  # Default to OSV
            help="Choose which PLONK model variant to use for prediction.",
        )

        # Modify the slider columns to accommodate both controls
        col_slider1, col_slider2 = st.columns([0.5, 0.5])
        with col_slider1:
            cfg_value = st.slider(
                "🎯 Guidance scale",
                min_value=0.0,
                max_value=5.0,
                value=0.0,
                step=0.1,
                help="Scale for classifier-free guidance during sampling. A small value preserves the diversity of the model's predictions, while a large value makes them more conservative but potentially more accurate.",
            )
        with col_slider2:
            num_samples = st.number_input(
                "🎲 Number of samples",
                min_value=1,
                max_value=5000,
                value=64,
                step=1,
                help="Number of location predictions to generate. More samples give better coverage but take longer to compute.",
            )
| st.markdown("### πΈ Choose your image") | |
| tab1, tab2, tab3 = st.tabs(["Upload", "URL", "Examples"]) | |
| with tab1: | |
| uploaded_file = st.file_uploader( | |
| "Choose an image...", | |
| type=["png", "jpg", "jpeg"], | |
| help="Supported formats: PNG, JPG, JPEG", | |
| ) | |
| if uploaded_file is not None: | |
| st.markdown('<div class="upload-container">', unsafe_allow_html=True) | |
| original_image = Image.open(uploaded_file) | |
| display_image = resize_image_for_display( | |
| original_image.copy(), max_size=300 | |
| ) | |
| st.image( | |
| display_image, caption="Uploaded Image", use_container_width=True | |
| ) | |
| st.markdown("</div>", unsafe_allow_html=True) | |
| if st.button("π Predict Location", key="predict_upload"): | |
| predictions = predict_location( | |
| original_image, | |
| model_name=model_name, | |
| cfg=cfg_value, | |
| num_samples=num_samples, | |
| ) | |
| st.session_state["predictions"] = predictions | |
        with tab2:
            url = st.text_input("Enter image URL:", key="image_url")
            if url:
                image = load_image_from_url(url)
                if image:
                    st.markdown(
                        '<div class="upload-container">', unsafe_allow_html=True
                    )
                    display_image = resize_image_for_display(image.copy(), max_size=300)
                    st.image(
                        display_image,
                        caption="Image from URL",
                        use_container_width=True,
                    )
                    st.markdown("</div>", unsafe_allow_html=True)

                    if st.button("🌍 Predict Location", key="predict_url"):
                        predictions = predict_location(
                            image,
                            model_name=model_name,
                            cfg=cfg_value,
                            num_samples=num_samples,
                        )
                        st.session_state["predictions"] = predictions
        with tab3:
            examples = load_example_images()
            # Guard against an empty dict: st.columns(0) would raise an error.
            if examples:
                st.markdown('<div class="examples-container">', unsafe_allow_html=True)
                example_cols = st.columns(len(examples))
                for idx, (name, path) in enumerate(examples.items()):
                    with example_cols[idx]:
                        original_image = Image.open(path)
                        display_image = resize_image_for_display(
                            original_image.copy(), max_size=150
                        )
                        if st.container().button(
                            "📸",
                            key=f"img_{name}",
                            help=f"Click to predict location for {name}",
                            use_container_width=True,
                        ):
                            predictions = predict_location(
                                original_image,
                                model_name=model_name,
                                cfg=cfg_value,
                                num_samples=num_samples,
                            )
                            st.session_state["predictions"] = predictions
                            st.rerun()
                        st.image(display_image, caption=name, use_container_width=True)
                st.markdown("</div>", unsafe_allow_html=True)
    with col2:
        st.markdown("### 📍 Predicted Locations")
        if "predictions" in st.session_state:
            pred = st.session_state["predictions"]

            # Create DataFrame for all predictions
            df = pd.DataFrame(
                {
                    "lat": pred["lat"],
                    "lon": pred["lon"],
                    "type": ["Sample"] * len(pred["lat"]),
                }
            )

            # Add high-confidence prediction
            df = pd.concat(
                [
                    df,
                    pd.DataFrame(
                        {
                            "lat": [pred["high_conf_lat"]],
                            "lon": [pred["high_conf_lon"]],
                            "type": ["Best Guess"],
                        }
                    ),
                ]
            )

            # Create a more interactive map using Plotly
            fig = px.scatter_mapbox(
                df,
                lat="lat",
                lon="lon",
                zoom=2,
                opacity=0.6,
                color="type",
                color_discrete_map={"Sample": "blue", "Best Guess": "red"},
                mapbox_style="carto-positron",
            )
            fig.update_traces(selector=dict(name="Best Guess"), marker_size=15)
            fig.update_layout(
                margin={"r": 0, "t": 0, "l": 0, "b": 0},
                height=500,
                showlegend=True,
                legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
            )

            # Display map in a container
            with st.container():
                st.plotly_chart(fig, use_container_width=True)

            # Display stats in a styled container
            with st.container():
                st.markdown(
                    f"""
                    <div class="prediction-box">
                        <h4>📊 Prediction Statistics</h4>
                        <p>Number of sampled locations: {len(pred["lat"])}</p>
                        <p>Best guess location: {pred["high_conf_lat"]:.2f}°, {pred["high_conf_lon"]:.2f}°</p>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
        else:
            # Empty state with better styling
            st.markdown(
                """
                <div class="prediction-box" style="text-align: center;">
                    <h4>👆 Upload an image and click 'Predict Location'</h4>
                    <p>The predicted locations will appear here on an interactive map.</p>
                </div>
                """,
                unsafe_allow_html=True,
            )

if __name__ == "__main__":
    main()