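"""Streamlit app for exploring and comparing NER evaluation metrics.

Pick a predefined example, optionally add your own predictions with the
annotation widget, and compare how the available span-based and token-based
metrics score each prediction against the ground truth.
"""
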
import pandas as pd
import streamlit as st
from annotated_text.util import get_annotated_html
from streamlit_annotation_tools import text_labeler

from constants import (
    APP_INTRO,
    APP_TITLE,
    EVAL_FUNCTION_INTRO,
    EVAL_FUNCTION_PROPERTIES,
    NER_TASK_EXPLAINER,
    PREDICTION_ADDITION_INSTRUCTION,
    SPAN_BASED_METRICS_EXPLANATION,
    TOKEN_BASED_METRICS_EXPLANATION,
)
from evaluation_metrics import EVALUATION_METRICS
from predefined_example import EXAMPLES
from span_dataclass_converters import (
    get_highlight_spans_from_ner_spans,
    get_ner_spans_from_annotations,
)


def get_examples_attributes(selected_example):
    """Return example attributes so that they are not refreshed on every interaction."""
    return (
        selected_example.text,
        selected_example.gt_labels,
        selected_example.gt_spans,
        selected_example.predictions,
        selected_example.tags,
    )


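# NOTE: Streamlit reruns this script from top to bottom on every widget
# interaction, so the whole UI below is rebuilt each time.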
if __name__ == "__main__":
    st.set_page_config(layout="wide")

    st.markdown(
        f"<h1 style='text-align: center; color: grey;'>{APP_TITLE}</h1>",
        unsafe_allow_html=True,
    )
    st.write(APP_INTRO)

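    # Two tabs: a static explanation of the metrics and an interactive comparison.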
    explanation_tab, comparison_tab = st.tabs(["📙 Explanation", "⚖️ Comparison"])

    with explanation_tab:
        st.write(EVAL_FUNCTION_INTRO)
        st.image("assets/eval_fnc_viz.png", caption="Evaluation Function Flow")
        st.markdown(EVAL_FUNCTION_PROPERTIES)
        st.markdown(NER_TASK_EXPLAINER)

        st.subheader("Evaluation Metrics")
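        # Build a numbered markdown list of the available metric names.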
        metric_names = "\n".join(
            [
                f"{index + 1}. " + evaluation_metric.name
                for index, evaluation_metric in enumerate(EVALUATION_METRICS)
            ]
        )
        st.markdown(
            "The evaluation metrics available for the NER task are:\n"
            "\n"
            f"{metric_names}"
        )
        st.markdown(
            "These metrics can be broadly classified as 'Span Based' and 'Token Based' metrics."
        )

        st.markdown("### Span Based Metrics")
        st.markdown(SPAN_BASED_METRICS_EXPLANATION)
        st.markdown("### Token Based Metrics")
        st.markdown(TOKEN_BASED_METRICS_EXPLANATION)

        st.divider()
        st.markdown(
            "Now that you have read the basics of the metrics calculation, head to the comparison tab to try out some examples!"
        )

    with comparison_tab:
        st.subheader("Ground Truth & Predictions")
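        # Everything below is driven by the selected predefined example: its text,
        # ground-truth labels and spans, candidate predictions, and tag set.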
        selected_example = st.selectbox(
            "Select an example text from the drop-down below",
            list(EXAMPLES),
            format_func=lambda ex: ex.text,
        )
        text, gt_labels, gt_spans, predictions, tags = get_examples_attributes(
            selected_example
        )

        annotated_predictions = [
            get_annotated_html(get_highlight_spans_from_ner_spans(ner_span, text))
            for ner_span in predictions
        ]
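        # The first entry in `predictions` is displayed as the "Ground Truth" row;
        # the remaining len(predictions) - 1 entries become Prediction_0, Prediction_1, ...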
        predictions_df = pd.DataFrame(
            {
                "Prediction": annotated_predictions,
                "ner_spans": predictions,
            },
            index=["Ground Truth"]
            + [f"Prediction_{index}" for index in range(len(predictions) - 1)],
        )

        with st.expander("Click to Add Predictions"):
            st.subheader("Adding predictions")
            st.markdown(PREDICTION_ADDITION_INSTRUCTION)
            st.write(
                "Note: Only the spans of the currently selected label name are shown. "
                "Click on a label to see its spans, or view the JSON below.",
            )
            labels = text_labeler(text, gt_labels)
            st.json(labels, expanded=False)

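            # Convert the widget's annotations into sorted NER spans, append them as a
            # new prediction, and rebuild the predictions table.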
            if st.button("Add!"):
                spans = get_ner_spans_from_annotations(labels)
                spans = sorted(spans, key=lambda span: span["start"])

                predictions.append(spans)
                annotated_predictions.append(
                    get_annotated_html(get_highlight_spans_from_ner_spans(spans, text))
                )
                predictions_df = pd.DataFrame(
                    {
                        "Prediction": annotated_predictions,
                        "ner_spans": predictions,
                    },
                    index=["Ground Truth"]
                    + [f"Prediction_{index}" for index in range(len(predictions) - 1)],
                )

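        # The Prediction column contains raw HTML from get_annotated_html, so the table
        # is rendered via to_html(escape=False) with unsafe_allow_html=True rather than
        # st.dataframe, which would show the markup as plain text.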
        highlighted_predictions_df = predictions_df[["Prediction"]]
        st.write(
            highlighted_predictions_df.to_html(escape=False), unsafe_allow_html=True
        )
        st.divider()

        ### EVALUATION METRICS COMPARISON ###
        st.subheader("Evaluation Metrics Comparison")
        st.markdown(
            "The evaluation metrics available for the NER task are shown below; "
            "select the ones you want to compare.\n"
        )
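        # One checkbox per metric; keep only the names whose checkbox is ticked.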
        metrics_selection = [
            (st.checkbox(evaluation_metric.name, value=True), evaluation_metric.name)
            for evaluation_metric in EVALUATION_METRICS
        ]
        metrics_to_show = [name for checked, name in metrics_selection if checked]

        with st.expander("View Prediction Details"):
            st.write(predictions_df.to_html(escape=False), unsafe_allow_html=True)

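        # Score every row (including the "Ground Truth" row) against the ground-truth
        # spans with each available metric, then show only the selected metric columns.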
        if st.button("Get Metrics!"):
            for evaluation_metric in EVALUATION_METRICS:
                predictions_df[evaluation_metric.name] = predictions_df.ner_spans.apply(
                    lambda ner_spans: evaluation_metric.get_evaluation_metric(
                        gt_ner_span=gt_spans,
                        pred_ner_span=ner_spans,
                        text=text,
                        tags=tags,
                    )
                )

            metrics_df = predictions_df.drop(["ner_spans"], axis=1)
            st.write(
                metrics_df[["Prediction"] + metrics_to_show].to_html(escape=False),
                unsafe_allow_html=True,
            )