Spaces:
Runtime error
Runtime error
| """Show count, mean and median loss per token and label.""" | |
| import streamlit as st | |
| from src.subpages.page import Context, Page | |
| from src.utils import AgGrid, aggrid_interactive_table | |
def get_loss_by_token(df_tokens):
    """Summarise the loss of every distinct token, largest total loss first.

    Args:
        df_tokens: DataFrame containing (at least) a ``tokens`` column and a
            ``losses`` column to aggregate.

    Returns:
        DataFrame with columns ``tokens``, ``count``, ``mean``, ``median`` and
        ``sum`` (one row per distinct token), sorted by ``sum`` descending.
    """
    # Aggregating the single ``losses`` series yields flat column names
    # directly, so no multi-level column header has to be dropped afterwards.
    per_token = df_tokens.groupby("tokens")["losses"]
    stats = per_token.agg(["count", "mean", "median", "sum"])
    # Tokens are ranked by their *total* loss contribution.
    ranked = stats.sort_values(by="sum", ascending=False)
    return ranked.reset_index()
def get_loss_by_label(df_tokens):
    """Summarise the loss of every label, highest average loss first.

    Args:
        df_tokens: DataFrame containing (at least) a ``labels`` column and a
            ``losses`` column to aggregate.

    Returns:
        DataFrame with columns ``labels``, ``count``, ``mean``, ``median`` and
        ``sum`` (one row per label), sorted by ``mean`` descending.
    """
    summary = df_tokens.groupby("labels")["losses"].agg(
        ["count", "mean", "median", "sum"]
    )
    # Unlike the per-token table, labels are ranked by mean loss, not total loss.
    return summary.sort_values(by="mean", ascending=False).reset_index()
class LossesPage(Page):
    """Page showing count, mean, median and sum of the loss per token and per label."""

    name = "Loss by Token/Label"
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Render the per-token and per-label loss tables side by side.

        Args:
            context: Shared page context. This page reads
                ``context.df_tokens_merged`` (used when "Merge tokens" is
                checked) and ``context.df_tokens_cleaned`` (used otherwise,
                and always for the label table).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show count, mean and median loss per token and label.")
            st.write(
                "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
            )

        # Wide token table on the left, a narrow spacer, label table on the right.
        col1, _, col2 = st.columns([8, 1, 6])

        with col1:
            st.subheader("💬 Loss by Token")
            merge_tokens = st.checkbox("Merge tokens", value=True, key="merge_tokens")
            # Mirror the widget value under a non-widget key.
            # NOTE(review): presumably read by other pages/the app shell — confirm
            # before removing.
            st.session_state["_merge_tokens"] = merge_tokens

            df_tokens = (
                context.df_tokens_merged if merge_tokens else context.df_tokens_cleaned
            )
            loss_by_token = get_loss_by_token(df_tokens)
            aggrid_interactive_table(loss_by_token.round(3))

            st.write(
                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
            )

        with col2:
            st.subheader("🏷️ Loss by Label")
            loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(loss_by_label.round(3), height=200)