Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Update app.py, score_db.py, and requirements.txt
Browse files- app.py +144 -4
- requirements.txt +5 -0
- score_db.py +143 -0
    	
        app.py
    CHANGED
    
    | @@ -1,7 +1,147 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 |  | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
            -
            demo = gr.Interface(fn=greet, inputs="text", outputs="text")
         | 
| 7 | 
            -
            demo.launch()
         | 
|  | |
| 1 | 
            +
            import base64
         | 
| 2 | 
            +
            import io
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            from io import BytesIO
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import matplotlib
         | 
| 7 | 
            +
            matplotlib.use('Agg')
         | 
| 8 | 
            +
            import matplotlib.pyplot as plt
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            from PIL import Image
         | 
| 11 | 
            +
            import requests
         | 
| 12 | 
            +
            from datasets import load_dataset
         | 
| 13 | 
             
            import gradio as gr
         | 
| 14 |  | 
| 15 | 
            +
            from score_db import Battle
         | 
| 16 | 
            +
            from score_db import Model as ModelEnum, Winner
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            def make_plot(seismic, predicted_image):
         | 
| 19 | 
            +
                fig, ax = plt.subplots(1, 1, figsize=(10, 10))
         | 
| 20 | 
            +
                ax.imshow(Image.fromarray(seismic), cmap="gray")
         | 
| 21 | 
            +
                ax.imshow(predicted_image, cmap="Reds", alpha=0.5, vmin=0, vmax=1)
         | 
| 22 | 
            +
                ax.set_axis_off()
         | 
| 23 | 
            +
                fig.canvas.draw()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                # Create a bytes buffer to save the plot
         | 
| 26 | 
            +
                buf = io.BytesIO()
         | 
| 27 | 
            +
                plt.savefig(buf, format='png', bbox_inches='tight')
         | 
| 28 | 
            +
                buf.seek(0)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                # Open the PNG image from the buffer and convert it to a NumPy array
         | 
| 31 | 
            +
                image = np.array(Image.open(buf))
         | 
| 32 | 
            +
                return image
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def call_endpoint(model: ModelEnum, img_array, url: str="https://lukasmosser--seisbase-endpoints-predict.modal.run"):
         | 
| 35 | 
            +
                response = requests.post(url, json={"img": img_array.tolist(), "model": model})
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                if response:
         | 
| 38 | 
            +
                    # Parse the base64-encoded image data
         | 
| 39 | 
            +
                    if response.text.startswith("data:image/tiff;base64,"):
         | 
| 40 | 
            +
                        img_data_out = base64.b64decode(response.text.split(",")[1])
         | 
| 41 | 
            +
                        predicted_image = np.array(Image.open(BytesIO(img_data_out)))
         | 
| 42 | 
            +
                        return predicted_image
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            def select_random_image(dataset):
         | 
| 45 | 
            +
                idx = random.randint(0, len(dataset))
         | 
| 46 | 
            +
                return idx, np.array(dataset[idx]["seismic"])
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            def select_random_models():
         | 
| 49 | 
            +
                model_a = random.choice(list(ModelEnum))
         | 
| 50 | 
            +
                model_b = random.choice(list(ModelEnum))
         | 
| 51 | 
            +
                return model_a, model_b
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            # Create a Gradio interface
         | 
| 55 | 
            +
            with gr.Blocks() as evaluation:
         | 
| 56 | 
            +
                gr.Markdown("""
         | 
| 57 | 
            +
                ## Seismic Fault Detection Model Evaluation
         | 
| 58 | 
            +
                This application allows you to compare the performance of different seismic fault detection models. 
         | 
| 59 | 
            +
                Two models are selected randomly, and their predictions are displayed side by side. 
         | 
| 60 | 
            +
                You can choose the better model or mark it as a tie. The results are recorded and used to update the model ratings.
         | 
| 61 | 
            +
                """)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                battle = gr.State([])
         | 
| 64 | 
            +
                radio = gr.Radio(choices=["Less than 5 years", "5 to 20 years", "more than 20 years"], label="How much experience do you have in seismic fault interpretation?")
         | 
| 65 | 
            +
                with gr.Row():
         | 
| 66 | 
            +
                    output_img1 = gr.Image(label="Model A Image")
         | 
| 67 | 
            +
                    output_img2 = gr.Image(label="Model B Image")
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def show_images():
         | 
| 70 | 
            +
                    dataset = load_dataset("porestar/crossdomainfoundationmodeladaption-deepfault", split="valid")
         | 
| 71 | 
            +
                    idx, image_1 = select_random_image(dataset)
         | 
| 72 | 
            +
                    model_a, model_b = select_random_models()
         | 
| 73 | 
            +
                    fault_probability_1 = call_endpoint(model_a, image_1)
         | 
| 74 | 
            +
                    fault_probability_2 = call_endpoint(model_b, image_1)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    img_1 = make_plot(image_1, fault_probability_1)
         | 
| 77 | 
            +
                    img_2 = make_plot(image_1, fault_probability_2)
         | 
| 78 | 
            +
                    experience = 1 
         | 
| 79 | 
            +
                    if radio.value == "5 to 20 years":
         | 
| 80 | 
            +
                        experience = 2
         | 
| 81 | 
            +
                    elif radio.value == "more than 20 years":
         | 
| 82 | 
            +
                        experience = 3
         | 
| 83 | 
            +
                    battle.value.append(Battle(model_a=model_a, model_b=model_b, winner="tie", judge="None", experience=experience, image_idx=idx))
         | 
| 84 | 
            +
                    return img_1, img_2
         | 
| 85 | 
            +
                
         | 
| 86 | 
            +
                # Define the function to make an API call
         | 
| 87 | 
            +
                def make_api_call(choice: Winner):
         | 
| 88 | 
            +
                    api_url = "https://lukasmosser--seisbase-eval-add-battle.modal.run"
         | 
| 89 | 
            +
                    battle_out = battle.value 
         | 
| 90 | 
            +
                    battle_out[-1].winner = choice
         | 
| 91 | 
            +
                    experience = 1 
         | 
| 92 | 
            +
                    if radio.value == "5 to 20 years":
         | 
| 93 | 
            +
                        experience = 2
         | 
| 94 | 
            +
                    elif radio.value == "more than 20 years":
         | 
| 95 | 
            +
                        experience = 3
         | 
| 96 | 
            +
                    battle_out[-1].experience = experience
         | 
| 97 | 
            +
                    response = requests.post(api_url, json=battle_out[-1].dict())
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # Load images on startup
         | 
| 100 | 
            +
                evaluation.load(show_images, inputs=[], outputs=[output_img1, output_img2])
         | 
| 101 | 
            +
                
         | 
| 102 | 
            +
                with gr.Row():
         | 
| 103 | 
            +
                    btn_winner_a = gr.Button("Winner Model A")
         | 
| 104 | 
            +
                    btn_tie = gr.Button("Tie")
         | 
| 105 | 
            +
                    btn_winner_b = gr.Button("Winner Model B")
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                # Define button click events
         | 
| 108 | 
            +
                btn_winner_a.click(lambda: make_api_call(Winner.model_a), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
         | 
| 109 | 
            +
                btn_tie.click(lambda: make_api_call(Winner.tie), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
         | 
| 110 | 
            +
                btn_winner_b.click(lambda: make_api_call(Winner.model_b), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            with gr.Blocks() as leaderboard:
         | 
| 113 | 
            +
                def get_results():
         | 
| 114 | 
            +
                    response = requests.get("https://lukasmosser--seisbase-eval-compute-ratings.modal.run")
         | 
| 115 | 
            +
                    data = response.json()
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    models = [entry["model"] for entry in data]
         | 
| 118 | 
            +
                    elo_ratings = [entry["elo_rating"] for entry in data]
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    fig, ax = plt.subplots()
         | 
| 121 | 
            +
                    ax.barh(models, elo_ratings, color='skyblue')
         | 
| 122 | 
            +
                    ax.set_xlabel('ELO Rating')
         | 
| 123 | 
            +
                    ax.set_title('Model ELO Ratings')
         | 
| 124 | 
            +
                    plt.tight_layout()
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    fig.canvas.draw()
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    # Create a bytes buffer to save the plot
         | 
| 129 | 
            +
                    buf = io.BytesIO()
         | 
| 130 | 
            +
                    plt.savefig(buf, format='png', bbox_inches='tight')
         | 
| 131 | 
            +
                    buf.seek(0)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # Open the PNG image from the buffer and convert it to a NumPy array
         | 
| 134 | 
            +
                    image = np.array(Image.open(buf))
         | 
| 135 | 
            +
                    return image
         | 
| 136 | 
            +
                
         | 
| 137 | 
            +
                with gr.Row():
         | 
| 138 | 
            +
                    elo_ratings = gr.Image(label="ELO Ratings")
         | 
| 139 | 
            +
                
         | 
| 140 | 
            +
                leaderboard.load(get_results, inputs=[], outputs=[elo_ratings])
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            demo = gr.TabbedInterface([evaluation, leaderboard], ["Arena", "Leaderboard"])
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            # Launch the interface
         | 
| 145 | 
            +
            if __name__ == "__main__":
         | 
| 146 | 
            +
                demo.launch(show_error=True)
         | 
| 147 |  | 
|  | |
|  | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            matplotlib
         | 
| 2 | 
            +
            numpy
         | 
| 3 | 
            +
            gradio
         | 
| 4 | 
            +
            datasets
         | 
| 5 | 
            +
            requests
         | 
    	
        score_db.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import csv
         | 
| 2 | 
            +
            import io
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from datetime import datetime
         | 
| 6 | 
            +
            from enum import Enum
         | 
| 7 | 
            +
            from pathlib import Path
         | 
| 8 | 
            +
            from typing import List
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import pandas as pd
         | 
| 11 | 
            +
            from fastapi import Response
         | 
| 12 | 
            +
            from modal import web_endpoint
         | 
| 13 | 
            +
            import modal
         | 
| 14 | 
            +
            from pydantic import BaseModel
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from rating import compute_mle_elo
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # -----------------------
         | 
| 19 | 
            +
            # Data Model Definition
         | 
| 20 | 
            +
            # -----------------------
         | 
| 21 | 
            +
            class ExperienceEnum(int, Enum):
         | 
| 22 | 
            +
                novice = 1
         | 
| 23 | 
            +
                intermediate = 2
         | 
| 24 | 
            +
                expert = 3
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            class Winner(str, Enum):
         | 
| 27 | 
            +
                model_a = "model_a"
         | 
| 28 | 
            +
                model_b = "model_b"
         | 
| 29 | 
            +
                tie = "tie"
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            class Model(str, Enum):
         | 
| 32 | 
            +
                porestar_deepfault_unet_baseline_1 = "porestar/deepfault-unet-baseline-1"
         | 
| 33 | 
            +
                porestar_deepfault_unet_baseline_2 = "porestar/deepfault-unet-baseline-2"
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            class Battle(BaseModel):
         | 
| 36 | 
            +
                model_a: Model
         | 
| 37 | 
            +
                model_b: Model
         | 
| 38 | 
            +
                winner: Winner
         | 
| 39 | 
            +
                judge: str
         | 
| 40 | 
            +
                image_idx: int 
         | 
| 41 | 
            +
                experience: ExperienceEnum = ExperienceEnum.novice
         | 
| 42 | 
            +
                tstamp: str = str(datetime.now())
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            class EloRating(BaseModel):
         | 
| 45 | 
            +
                model: Model
         | 
| 46 | 
            +
                elo_rating: float
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            # -----------------------
         | 
| 49 | 
            +
            # Modal Configuration
         | 
| 50 | 
            +
            # -----------------------
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            # Create a volume to persist data
         | 
| 53 | 
            +
            data_volume = modal.Volume.from_name("seisbase-data", create_if_missing=True)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            JSON_FILE_PATH = Path("/data/battles.json")
         | 
| 56 | 
            +
            RESULTS_FILE_PATH = Path("/data/ratings.csv")
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            app_image = modal.Image.debian_slim(python_version="3.10").pip_install("pandas", "scikit-learn", "tqdm", "sympy")
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            app = modal.App(
         | 
| 61 | 
            +
                image=app_image,
         | 
| 62 | 
            +
                name="seisbase-eval",
         | 
| 63 | 
            +
                volumes={"/data": data_volume},
         | 
| 64 | 
            +
            )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            def ensure_json_file():
         | 
| 67 | 
            +
                """Ensure the JSON file exists and is initialized with an empty array if necessary."""
         | 
| 68 | 
            +
                if not os.path.exists(JSON_FILE_PATH):
         | 
| 69 | 
            +
                    JSON_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)
         | 
| 70 | 
            +
                    with open(JSON_FILE_PATH, "w") as f:
         | 
| 71 | 
            +
                        json.dump([], f)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            def append_to_json_file(data):
         | 
| 74 | 
            +
                """Append data to the JSON file."""
         | 
| 75 | 
            +
                ensure_json_file()
         | 
| 76 | 
            +
                try:
         | 
| 77 | 
            +
                    with open(JSON_FILE_PATH, "r+") as f:
         | 
| 78 | 
            +
                        try:
         | 
| 79 | 
            +
                            battles = json.load(f)
         | 
| 80 | 
            +
                        except json.JSONDecodeError:
         | 
| 81 | 
            +
                            # Reset the file if corrupted
         | 
| 82 | 
            +
                            battles = []
         | 
| 83 | 
            +
                        battles.append(data)
         | 
| 84 | 
            +
                        f.seek(0)
         | 
| 85 | 
            +
                        json.dump(battles, f, indent=4)
         | 
| 86 | 
            +
                        f.truncate()
         | 
| 87 | 
            +
                except Exception as e:
         | 
| 88 | 
            +
                    raise RuntimeError(f"Failed to append data to JSON file: {e}")
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            def read_json_file():
         | 
| 91 | 
            +
                """Read data from the JSON file."""
         | 
| 92 | 
            +
                ensure_json_file()
         | 
| 93 | 
            +
                try:
         | 
| 94 | 
            +
                    with open(JSON_FILE_PATH, "r") as f:
         | 
| 95 | 
            +
                        try:
         | 
| 96 | 
            +
                            return json.load(f)
         | 
| 97 | 
            +
                        except json.JSONDecodeError:
         | 
| 98 | 
            +
                            return []  # Return an empty list if the file is corrupted
         | 
| 99 | 
            +
                except Exception as e:
         | 
| 100 | 
            +
                    raise RuntimeError(f"Failed to read JSON file: {e}")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            @app.function()
         | 
| 103 | 
            +
            @web_endpoint(method="POST", docs=True)
         | 
| 104 | 
            +
            def add_battle(battle: Battle):
         | 
| 105 | 
            +
                """Add a new battle to the JSON file."""
         | 
| 106 | 
            +
                append_to_json_file(battle.dict())
         | 
| 107 | 
            +
                return {"status": "success", "battle": battle.dict()}
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            @app.function()
         | 
| 111 | 
            +
            @web_endpoint(method="GET", docs=True)
         | 
| 112 | 
            +
            def export_csv():
         | 
| 113 | 
            +
                """Fetch all battles and return as CSV."""
         | 
| 114 | 
            +
                battles = read_json_file()
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                # Create CSV in memory
         | 
| 117 | 
            +
                output = io.StringIO()
         | 
| 118 | 
            +
                writer = csv.DictWriter(output, fieldnames=["model_a", "model_b", "winner", "judge", "imaged_idx", "experience", "tstamp"])
         | 
| 119 | 
            +
                writer.writeheader()
         | 
| 120 | 
            +
                writer.writerows(battles)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                csv_data = output.getvalue()
         | 
| 123 | 
            +
                return Response(content=csv_data, media_type="text/csv")
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            @app.function()
         | 
| 126 | 
            +
            @web_endpoint(method="GET", docs=True)
         | 
| 127 | 
            +
            def compute_ratings() -> List[EloRating]:
         | 
| 128 | 
            +
                """Compute ratings from battles."""
         | 
| 129 | 
            +
                battles = pd.read_json(JSON_FILE_PATH, dtype=[str, str, str, str, int, int, str]).sort_values(ascending=True, by=["tstamp"]).reset_index(drop=True)
         | 
| 130 | 
            +
                elo_mle_ratings = compute_mle_elo(battles)
         | 
| 131 | 
            +
                elo_mle_ratings.to_csv(RESULTS_FILE_PATH)
         | 
| 132 | 
            +
                
         | 
| 133 | 
            +
                df = pd.read_csv(RESULTS_FILE_PATH)
         | 
| 134 | 
            +
                df.columns = ["Model", "Elo rating"]
         | 
| 135 | 
            +
                df = df.sort_values("Elo rating", ascending=False).reset_index(drop=True)
         | 
| 136 | 
            +
                scores = []
         | 
| 137 | 
            +
                for i in range(len(df)):
         | 
| 138 | 
            +
                    scores.append(EloRating(model=df["Model"][i], elo_rating=df["Elo rating"][i]))
         | 
| 139 | 
            +
                return scores
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            @app.local_entrypoint()
         | 
| 142 | 
            +
            def main():
         | 
| 143 | 
            +
                print("Local entrypoint running. Check endpoints for functionality.")
         | 
