Spaces:
Runtime error
Runtime error
| """This page contains all misclassified examples and allows filtering by specific error types.""" | |
| from collections import defaultdict | |
| import pandas as pd | |
| import streamlit as st | |
| from sklearn.metrics import confusion_matrix | |
| from src.subpages.page import Context, Page | |
| from src.utils import htmlify_labeled_example | |
| class MisclassifiedPage(Page): | |
| name = "Misclassified" | |
| icon = "x-octagon" | |
| def render(self, context: Context): | |
| st.title(self.name) | |
| with st.expander("💡", expanded=True): | |
| st.write( | |
| "This page contains all misclassified examples and allows filtering by specific error types." | |
| ) | |
| misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique() | |
| misclassified_samples = context.df_tokens_merged.loc[misclassified_indices] | |
| cm = confusion_matrix( | |
| misclassified_samples.labels, | |
| misclassified_samples.preds, | |
| labels=context.labels, | |
| ) | |
| # st.pyplot( | |
| # plot_confusion_matrix( | |
| # y_preds=misclassified_samples["preds"], | |
| # y_true=misclassified_samples["labels"], | |
| # labels=labels, | |
| # normalize=None, | |
| # zero_diagonal=True, | |
| # ), | |
| # ) | |
| df = pd.DataFrame(cm, index=context.labels, columns=context.labels).astype(str) | |
| import numpy as np | |
| np.fill_diagonal(df.values, "") | |
| st.dataframe(df.applymap(lambda x: x if x != "0" else "")) | |
| # import matplotlib.pyplot as plt | |
| # st.pyplot(df.style.background_gradient(cmap='RdYlGn_r').to_html()) | |
| # selection = aggrid_interactive_table(df) | |
| # st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True) | |
| confusions = defaultdict(int) | |
| for i, row in enumerate(cm): | |
| for j, _ in enumerate(row): | |
| if i == j or cm[i][j] == 0: | |
| continue | |
| confusions[(context.labels[i], context.labels[j])] += cm[i][j] | |
| def format_func(item): | |
| return ( | |
| f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All" | |
| ) | |
| conf = st.radio( | |
| "Filter by Class Confusion", | |
| options=list(zip(confusions.keys(), confusions.values())), | |
| format_func=format_func, | |
| ) | |
| # st.write( | |
| # f"**Filtering Examples:** True class: `{conf[0][0]}`, Predicted class: `{conf[0][1]}`" | |
| # ) | |
| filtered_indices = misclassified_samples.query( | |
| f"labels == '{conf[0][0]}' and preds == '{conf[0][1]}'" | |
| ).index | |
| for i, idx in enumerate(filtered_indices): | |
| sample = context.df_tokens_merged.loc[idx] | |
| st.write( | |
| htmlify_labeled_example(sample), | |
| unsafe_allow_html=True, | |
| ) | |
| st.write("---") | |