Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import scipy.stats | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| from datasets import load_dataset | |
| from huggingface_hub import HfApi | |
# Logging: a dedicated "app" logger that writes timestamped records to a stream handler.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger("app")
logger.setLevel(logging.INFO)
logger.addHandler(ch)

# Only let warnings and above through from the (noisy) absl logger
logging.getLogger("absl").setLevel(logging.WARNING)

# Hub dataset holding the evaluation results, and how often (in seconds) it is re-read
RESULTS_REPO = "open-rl-leaderboard/results_v2"
REFRESH_RATE = 5 * 60  # 5 minutes

# Hub client; the token comes from the TOKEN environment variable (None when unset)
API = HfApi(token=os.getenv("TOKEN"))
# Environment ids shown on the leaderboard, grouped by domain (one top-level tab per domain).
# Every Atari id shares the "NoFrameskip-v4" suffix, so only the game names are listed
# and the suffix is appended when the list is built.
ALL_ENV_IDS = {
    "Atari": [
        f"{game}NoFrameskip-v4"
        for game in (
            "Adventure",
            "AirRaid",
            "Alien",
            "Amidar",
            "Assault",
            "Asterix",
            "Asteroids",
            "Atlantis",
            "BankHeist",
            "BattleZone",
            "BeamRider",
            "Berzerk",
            "Bowling",
            "Boxing",
            "Breakout",
            "Carnival",
            "Centipede",
            "ChopperCommand",
            "CrazyClimber",
            "Defender",
            "DemonAttack",
            "DoubleDunk",
            "ElevatorAction",
            "Enduro",
            "FishingDerby",
            "Freeway",
            "Frostbite",
            "Gopher",
            "Gravitar",
            "Hero",
            "IceHockey",
            "Jamesbond",
            "JourneyEscape",
            "Kangaroo",
            "Krull",
            "KungFuMaster",
            "MontezumaRevenge",
            "MsPacman",
            "NameThisGame",
            "Phoenix",
            "Pitfall",
            "Pong",
            "Pooyan",
            "PrivateEye",
            "Qbert",
            "Riverraid",
            "RoadRunner",
            "Robotank",
            "Seaquest",
            "Skiing",
            "Solaris",
            "SpaceInvaders",
            "StarGunner",
            "Tennis",
            "TimePilot",
            "Tutankham",
            "UpNDown",
            "Venture",
            "VideoPinball",
            "WizardOfWor",
            "YarsRevenge",
            "Zaxxon",
        )
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
    "PyBullet": [
        "AntBulletEnv-v0",
        "HalfCheetahBulletEnv-v0",
        "HopperBulletEnv-v0",
        "HumanoidBulletEnv-v0",
        "InvertedDoublePendulumBulletEnv-v0",
        "InvertedPendulumSwingupBulletEnv-v0",
        "MinitaurBulletEnv-v0",
        "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0",
    ],
}
def iqm(x):
    """Return the 25%-trimmed mean of ``x``.

    The values are flattened (``axis=None``) and SciPy's ``trim_mean`` drops
    the lowest and highest quarter of them before averaging what remains.
    """
    trim_fraction = 0.25
    return scipy.stats.trim_mean(x, trim_fraction, axis=None)
def get_leaderboard_df():
    """Download the results dataset and return the finished evaluations.

    Loads the ``train`` split of ``RESULTS_REPO``, keeps the rows whose
    ``status`` is ``"DONE"`` and adds an ``iqm_episodic_return`` column
    computed by applying ``iqm`` to each row's ``episodic_returns``.

    Returns:
        A pandas DataFrame that owns its data (not a view of the raw download).
    """
    logger.info("Downloading results")
    dataset = load_dataset(RESULTS_REPO, split="train")  # split is not important, but we need to use "train"
    df = dataset.to_pandas()  # convert to pandas dataframe
    # Keep only the models that are done. The explicit copy makes the column
    # assignment below write to an independent frame instead of a filtered
    # slice of the download (which pandas flags with SettingWithCopyWarning).
    df = df[df["status"] == "DONE"].copy()
    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return df
def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of ``df`` for ``env_id``, best first, with a 1-based ``ranking`` column.

    Rows are ordered by ``iqm_episodic_return`` in descending order. The input
    frame is left untouched; a new DataFrame is returned.
    """
    matches = df[df["env_id"] == env_id]
    ranks = np.arange(1, len(matches) + 1)
    return matches.sort_values("iqm_episodic_return", ascending=False).assign(ranking=ranks)
def format_df(df: pd.DataFrame):
    """Turn a ranked per-environment frame into rows for the leaderboard table.

    ``user_id`` and ``model_id`` are replaced by markdown links to the
    corresponding Hugging Face pages, and only the displayed columns are kept.

    Args:
        df: Frame with at least the ``ranking``, ``user_id``, ``model_id`` and
            ``iqm_episodic_return`` columns. It is not modified.

    Returns:
        A list of ``[ranking, user link, model link, iqm_episodic_return]``
        rows, in the row order of ``df``.
    """
    df = df.copy()
    users = df["user_id"].tolist()
    models = df["model_id"].tolist()
    # Add hyperlinks. The columns are built positionally in one pass; writing
    # cell by cell through ``df.loc[label, ...]`` would overwrite every row
    # sharing an index label and is much slower on large frames.
    df["user_id"] = [f"[{user_id}](https://huggingface.co/{user_id})" for user_id in users]
    df["model_id"] = [
        f"[{model_id}](https://huggingface.co/{user_id}/{model_id})" for user_id, model_id in zip(users, models)
    ]
    # Keep only the relevant columns
    df = df[["ranking", "user_id", "model_id", "iqm_episodic_return"]]
    return df.values.tolist()
def refresh_video(df, env_id):
    """Download the replay video of the top-ranked model for ``env_id``.

    Returns the local path given by ``hf_hub_download`` for ``replay.mp4`` at
    the recorded ``sha`` of the best model's repo. Returns ``None`` when the
    environment has no entry or when the download raises (the error is logged).
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return None

    best = env_df.iloc[0]
    user_id = best["user_id"]
    model_id = best["model_id"]
    sha = best["sha"]
    repo_id = f"{user_id}/{model_id}"
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=sha, repo_type="model")
    except Exception as e:
        # Best effort: a missing or broken video must not break the leaderboard
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None
def refresh_one_video(df, env_id):
    """Bind ``df`` and ``env_id`` into a zero-argument callback.

    The returned callable runs ``refresh_video`` on exactly the DataFrame
    object passed here and returns its result.
    """

    def inner():
        video_path = refresh_video(df, env_id)
        return video_path

    return inner
def refresh_winner(df, env_id):
    """Build the markdown shown next to the table of ``env_id``.

    With at least one entry, the text is a heading plus a link to the
    top-ranked model's Hugging Face page; otherwise it invites the reader to
    submit a model.
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return f"""## {env_id}
This leaderboard is quite empty... 😢
Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""

    best = env_df.iloc[0]
    user_id, model_id = best["user_id"], best["model_id"]
    url = f"https://huggingface.co/{user_id}/{model_id}"
    return f"""## {env_id}
### 🏆 [Best model]({url}) 🏆"""
def refresh_num_models(df):
    """Return the sentence stating how many rows ``df`` holds, with thousands separators."""
    count = len(df)
    return f"The leaderboard currently contains {count:,} models."
# Stylesheet passed to gr.Blocks(css=...): removes the border of elements
# carrying the "generating" class and centers h2/h3 headings.
css = """
.generating {
border: none;
}
h2 {
text-align: center;
}
h3 {
text-align: center;
}
"""
def update_globals():
    """Re-download the results and rebuild the module-level snapshot.

    Rebinds these globals:
        df: the full results frame from ``get_leaderboard_df``.
        dataframes: env_id -> table rows from ``format_df``.
        winner_texts: env_id -> markdown from ``refresh_winner``.
        video_pathes: env_id -> replay path (or ``None``) from ``refresh_video``.
        num_models_str: sentence from ``refresh_num_models``.

    Everything is computed into locals first and published in a single
    assignment at the end. If any step raises, the previous snapshot stays
    fully in place, and readers such as ``refresh`` do not observe new tables
    next to old winner texts while the (slow) video downloads are running.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df
    new_df = get_leaderboard_df()
    all_env_ids = [env_id for env_ids in ALL_ENV_IDS.values() for env_id in env_ids]
    new_dataframes = {env_id: format_df(select_env(new_df, env_id)) for env_id in all_env_ids}
    new_winner_texts = {env_id: refresh_winner(new_df, env_id) for env_id in all_env_ids}
    new_video_pathes = {env_id: refresh_video(new_df, env_id) for env_id in all_env_ids}
    new_num_models_str = refresh_num_models(new_df)

    # Publish the complete snapshot at once
    df, dataframes, winner_texts, video_pathes, num_models_str = (
        new_df,
        new_dataframes,
        new_winner_texts,
        new_video_pathes,
        new_num_models_str,
    )


# Build the first snapshot at import time so the UI has data on first load
update_globals()
def refresh():
    """Return the cached values for every table, every winner text, then the model count.

    The order is all ``dataframes`` values, all ``winner_texts`` values, and
    finally ``num_models_str``, read from the current module-level snapshot.
    """
    return [*dataframes.values(), *winner_texts.values(), num_models_str]
def _video_callback(env_id):
    """Return a zero-argument callback that fetches the best replay for ``env_id``."""

    def inner():
        # Read the module-level ``df`` when the event fires, not when the UI is
        # built: the scheduled refresh rebinds ``df``, and a frame captured at
        # build time would keep serving the winner from the startup snapshot.
        return refresh_video(df, env_id)

    return inner


# UI layout: heading, model count, then one tab per domain and one sub-tab per
# environment holding the ranking table, the winner text and the replay video.
with gr.Blocks(css=css) as demo:
    # Explicit encoding: the text files should not depend on the platform default
    with open("texts/heading.md", encoding="utf-8") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # env_id -> component, used to wire the refresh outputs below
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )
                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video
                            # The video is only fetched when its tab is selected
                            tab.select(_video_callback(env_id), outputs=[gr_video])
                    # Load the first video of the first environment
                    # NOTE(review): placed inside the per-domain loop so that env_ids[0]
                    # is each domain's first environment — confirm against the original layout.
                    demo.load(_video_callback(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])

        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

        with gr.TabItem("📝 About"):
            with open("texts/about.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

    # On page load, fill every table, every winner text and the model count
    # from the cached snapshot (same order as the list returned by refresh()).
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])
# Rebuild the cached snapshot in the background every REFRESH_RATE seconds;
# max_instances=1 keeps a slow refresh from overlapping with the next one.
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()

if __name__ == "__main__":
    demo.queue().launch()