Spaces:
Runtime error
Runtime error
import functools

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import HTML, display
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
# Global matplotlib look applied to every figure this app draws.
plt.style.use("ggplot")

# Checkpoints passed to `from_pretrained`; their order is the order shown in
# the model dropdown.
MODELS = [
    "TehranNLP-org/bert-base-uncased-cls-sst2",
    "TehranNLP-org/bert-large-sst2",
    "WillHeld/roberta-base-sst2",
]
def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Draw a horizontal bar chart of per-token scores on a new pyplot figure.

    Args:
        tokens: token strings, one bar per token, listed top to bottom and
            shown as tick labels on the right-hand y-axis.
        logits: per-token scores, same length as ``tokens`` (any array-like).
            Bars and token labels with score >= 0 are green, negative ones red.
        label_names: two x-tick labels placed at the left and right x-limits
            (e.g. ``["-", "+"]``), coloured red and green respectively.
        title: figure title.
        file_name: accepted for backward compatibility; ignored (the figure is
            not written to disk).

    Returns:
        The new ``matplotlib.figure.Figure``. It stays open as pyplot's current
        figure; the caller is responsible for closing it.
    """
    logits = np.asarray(logits, dtype=float)
    figure = plt.figure(figsize=(4.5, 5))
    axes = figure.gca()

    colors = ["#019875" if is_positive else "#B8293D" for is_positive in logits >= 0]
    axes.barh(range(len(tokens)), logits, color=colors)
    axes.axvline(0, color="black", ls="-", lw=2, alpha=0.2)
    axes.invert_yaxis()  # first token at the top

    # Range is symmetric with 0.2 of headroom. When every score is positive,
    # the negative side shrinks to a sliver. Empty input falls back to +/-0.2
    # instead of crashing inside np.max / np.min.
    max_limit = (float(np.max(np.abs(logits))) if logits.size else 0.0) + 0.2
    min_limit = -0.01 if logits.size and np.min(logits) > 0 else -max_limit
    axes.set_xlim(min_limit, max_limit)
    axes.set_xticks([min_limit, max_limit])
    axes.set_xticklabels(label_names, fontsize=14, fontweight="bold")

    axes.set_yticks(range(len(tokens)))
    axes.set_yticklabels(tokens)
    axes.yaxis.tick_right()
    for tick_label, color in zip(axes.get_yticklabels(), colors):
        tick_label.set_color(color)
        tick_label.set_fontweight("bold")
        tick_label.set_verticalalignment("center")
    for tick_label, color in zip(axes.get_xticklabels(), ["#B8293D", "#019875"]):
        tick_label.set_color(color)

    axes.set_title(title)
    figure.tight_layout()
    return figure
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Render tokens as one HTML ``<pre>`` line, each token shaded by its score.

    Args:
        importance: (sent_len) per-token scores aligned with ``tokenized_text``.
        tokenized_text: list of token strings.
        discrete: if True, shade by rank instead of magnitude. Ranks are
            non-negative, so every token is drawn green.
        prefix: text emitted before the first token (e.g. a row label).
        no_cls_sep: if True, drop the first and last token/score pair.

    Non-negative scores use the 'Greens' colormap; negative scores use 'Reds'
    on their absolute value. Characters ``<`` and ``>`` in tokens are shown as
    ``[`` and ``]``, and ``&`` is HTML-escaped.

    Returns:
        The HTML string. Nothing is displayed.
    """
    importance = np.asarray(importance, dtype=float)
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]

    # Scale so the strongest token maps to 1/1.5 (~0.67) on the colormap.
    # All-zero or empty rows are left unscaled rather than divided by zero,
    # which would turn every score into NaN.
    peak = np.abs(importance).max() if importance.size else 0.0
    if peak > 0:
        importance = importance / peak / 1.5
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    greens = matplotlib.colormaps["Greens"]
    reds = matplotlib.colormaps["Reds"]
    spans = []
    for i, token in enumerate(tokenized_text):
        score = importance[i]
        red, green, blue, alpha = (greens if score >= 0 else reds)(abs(score))
        # White text only on very dark backgrounds. Previously a later
        # "color:black" in the same style string always overrode it.
        text_color = "rgba(255, 255, 255, 1.0)" if abs(score) > 0.9 else "black"
        style = (
            f"background-color: rgba({red * 255}, {green * 255}, {blue * 255}, {alpha}); "
            f"color: {text_color}; border-radius: 5px; padding: 3px; font-weight: 800;"
        )
        text = token.replace("&", "&amp;").replace("<", "[").replace(">", "]")
        spans.append(f"<span style='{style}'>{text}</span> ")

    return "<pre style='color:black; padding: 3px;'>" + prefix + "".join(spans) + "</pre>"
def print_preview(decompx_outputs_df, idx=0, discrete=False):
    """Build the HTML token-highlighting preview for one row of DecompX outputs.

    Args:
        decompx_outputs_df: DataFrame with a "tokens" column and, optionally:
            - "importance_last_layer_aggregated": rendered from row 0 of the
              matrix, with ``discrete`` passed through;
            - "importance_last_layer_classifier": rendered once per label
              (each index of its last axis), never discrete.
            Columns that are missing, or whose entry is None, are skipped.
        idx: positional row index to render.
        discrete: forwarded to ``print_importance`` for the aggregated row.

    Returns:
        HTML string: a scrollable white ``<div>`` with one line per rendered row.
    """
    no_cls_sep = False
    df = decompx_outputs_df
    rows = []

    aggregated_col = "importance_last_layer_aggregated"
    if aggregated_col in df and df[aggregated_col].iloc[idx] is not None:
        rows.append(print_importance(
            df[aggregated_col].iloc[idx][0, :],
            df["tokens"].iloc[idx],
            prefix=f"{aggregated_col.split('_')[-1]}:".ljust(20),
            no_cls_sep=no_cls_sep,
            discrete=discrete,
        ))

    classifier_col = "importance_last_layer_classifier"
    if classifier_col in df and df[classifier_col].iloc[idx] is not None:
        classifier_importance = df[classifier_col].iloc[idx]
        # One row per output label. There is no extra "predicted label" row:
        # the DataFrame built by run_decompx has no "label" column to read.
        for label in range(classifier_importance.shape[-1]):
            rows.append(print_importance(
                classifier_importance[:, label],
                df["tokens"].iloc[idx],
                prefix=f"{classifier_col.split('_')[-1]} Label{label}:".ljust(20),
                no_cls_sep=no_cls_sep,
                discrete=False,
            ))

    return (
        "<div style='overflow:auto; background-color:white; padding: 10px;'>"
        + "".join(rows)
        + "</div>"
    )
@functools.lru_cache(maxsize=len(MODELS))
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and DecompX classifier for ``model_name`` once per process.

    Raises:
        ValueError: if ``model_name`` contains neither "roberta" nor "bert".
    """
    # "roberta" is tested first because every RoBERTa name also contains "bert".
    if "roberta" in model_name:
        model_class = RobertaForSequenceClassification
    elif "bert" in model_name:
        model_class = BertForSequenceClassification
    else:
        raise ValueError(f"Not implemented model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    classifier = model_class.from_pretrained(model_name)
    classifier.eval()
    # NOTE(review): caching assumes the DecompX forward pass does not mutate
    # model state between requests; confirm against DecompX.src.modeling_*.
    return tokenizer, classifier


def _decompx_config():
    """Return the DecompX configuration used by this demo."""
    return DecompXConfig(
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
    )


def run_decompx(text, model):
    """
    Provide DecompX Token Explanation of Model on Text

    Args:
        text: sentence to classify and explain.
        model: checkpoint name; must contain "roberta" or "bert".

    Returns:
        (figure, html): a matplotlib Figure with the per-token bar chart for the
        predicted label (scaled to max |value| = 1), and the HTML highlighting
        from ``print_preview``.

    Raises:
        ValueError: if no model is given, the text has no tokens besides the
        special ones, or the model name is not supported.
    """
    if not model:
        raise ValueError("Please select a model.")
    if not text or not text.strip():
        raise ValueError("Please enter some text to explain.")

    # The input is batched with a fixed filler sentence. NOTE(review): this
    # presumably keeps the batch size above 1 so the .squeeze() calls below
    # cannot drop the batch axis; confirm before using a single-sentence batch.
    sentences = [text, "nothing"]
    tokenizer, classifier = _load_model_and_tokenizer(model)
    tokenized_sentence = tokenizer(sentences, return_tensors="pt", padding=True)
    # Unpadded length of each sentence, counted from its attention mask.
    batch_lengths = tokenized_sentence["attention_mask"].sum(dim=-1).tolist()

    with torch.no_grad():
        logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = classifier(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            decompx_config=_decompx_config(),
        )

    decompx_outputs = {
        "tokens": [
            tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:length])
            for i, length in enumerate(batch_lengths)
        ],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        # Last layer, first position only -> (batch, emb_dim)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),
    }

    # Classifier attributions, expected ~(batch, seq_len, classes).
    classifier_importance = np.array(
        [g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]
    ).squeeze()
    decompx_outputs["importance_last_layer_classifier"] = [
        per_sentence[:batch_lengths[j], :] for j, per_sentence in enumerate(classifier_importance)
    ]

    # Aggregated attributions: (layers, batch, seq, seq) -> (batch, layers, seq, seq).
    aggregated_importance = np.array(
        [g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated]
    ).swapaxes(0, 1)
    decompx_outputs["importance_all_layers_aggregated"] = [
        per_sentence[:, :batch_lengths[j], :batch_lengths[j]]
        for j, per_sentence in enumerate(aggregated_importance)
    ]

    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    idx = 0  # row of the user's sentence; row 1 is the filler
    row = decompx_outputs_df.iloc[idx]
    pred_label = np.argmax(row["logits"], axis=-1)
    # Drop the first/last (special) tokens and plot the predicted label's scores.
    tokens = row["tokens"][1:-1]
    if len(tokens) == 0:
        raise ValueError("Please enter some text to explain.")
    label = row["importance_last_layer_classifier"][:, pred_label][1:-1]
    scale = np.max(np.abs(label))
    if scale > 0:  # an all-zero row would otherwise become NaN
        label = label / scale
    plot_clf(tokens, label, ['-', '+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")

    # Return the concrete Figure and unregister it from pyplot. Returning the
    # pyplot module left one open figure behind per request.
    figure = plt.gcf()
    plt.close(figure)
    return figure, print_preview(decompx_outputs_df)
# Gradio UI: text + model in, bar-chart plot + HTML highlighting out.
# NOTE(review): with cache_examples=True Gradio precomputes the examples
# below, which loads every listed checkpoint. Confirm startup time and memory
# fit the Space.
demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        # Preselect a model so submitting without choosing one does not pass
        # None to run_decompx.
        gr.components.Dropdown(label="Model", choices=MODELS, value=MODELS[0]),
    ],
    outputs=["plot", "html"],
    examples=[
        ["a good piece of work more often than not.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
        ["a good piece of work more often than not.", "TehranNLP-org/bert-large-sst2"],
        ["a good piece of work more often than not.", "WillHeld/roberta-base-sst2"],
        ["A deep and meaningful film.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
    ],
    cache_examples=True,
    title="DecompX Demo",
    description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()