Spaces:
Runtime error
Runtime error
| import random | |
| import re | |
| from poems import SAMPLE_POEMS | |
| import langid | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from icu_tokenizer import Tokenizer | |
| from transformers import pipeline | |
# Hugging Face Hub identifiers of the two fill-mask models the app compares,
# keyed by the short name used elsewhere in this script.
MODELS = dict(
    ALBERTI="flax-community/alberti-bert-base-multilingual-cased",
    mBERT="bert-base-multilingual-cased",
)

# How many of the highest-probability predictions are kept for a masked token.
TOPK = 50
# Use Streamlit's wide page layout; the page body is split into three columns
# further down in this script.
st.set_page_config(layout="wide")
def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of ``line`` with ``[MASK]``.

    The line is tokenized with the ICU tokenizer for ``language`` and the
    tokens are re-joined with single spaces after masking.

    Selection rules:

    * For every language except Chinese (``"zh"``), ``restrictive`` is
      recomputed from the line itself: it is true when at least one token is
      longer than three characters. The ``restrictive`` argument is therefore
      only honoured for Chinese.
    * In restrictive mode only tokens longer than three characters are
      eligible; for Chinese, alphabetic tokens of any length are eligible too.
    * In non-restrictive mode, or when no token is eligible, any token may be
      masked.

    Args:
        line: The verse to mask.
        language: Language code passed to the tokenizer.
        restrictive: Whether to restrict masking to "content" tokens
            (only effective for Chinese, see above).

    Returns:
        The masked line, or ``line`` unchanged when it contains no tokens.
    """
    tokenizer = Tokenizer(lang=language)
    token_list = tokenizer.tokenize(line)
    if not token_list:
        # Nothing to mask (e.g. a whitespace-only line). Without this guard
        # the random draw below would be made over an empty range.
        return line

    # Use the `language` parameter here: relying on a module-level `lang`
    # only works when the calling script happens to have defined it.
    is_chinese = language == "zh"
    if not is_chinese:
        restrictive = any(len(token) > 3 for token in token_list)

    eligible = []
    if restrictive:
        eligible = [
            index
            for index, token in enumerate(token_list)
            if len(token) > 3 or (is_chinese and token.isalpha())
        ]
    if not eligible:
        # Either non-restrictive mode or no token qualifies: fall back to
        # every position instead of retrying forever.
        eligible = list(range(len(token_list)))

    # A single uniform draw over the eligible positions gives the same
    # distribution as "draw any position and retry until it is eligible",
    # but always terminates.
    token_list[random.choice(eligible)] = "[MASK]"
    return " ".join(token_list)
def filter_candidates(candidates, get_any_candidate=False):
    """Pick the sequence of the best acceptable fill-mask candidate.

    Candidates are examined in the order given. A candidate is acceptable
    when its ``"token_str"`` is purely alphabetic and is not a word-piece
    continuation (does not start with ``"##"``). If no candidate is
    acceptable, or ``get_any_candidate`` is true, the first candidate is
    used regardless of its token.

    Args:
        candidates: Iterable of dicts with at least the keys ``"sequence"``
            and ``"token_str"``.
        get_any_candidate: Skip the token filter and take the first
            candidate.

    Returns:
        The ``"sequence"`` of the selected candidate, or ``""`` when
        ``candidates`` is empty.
    """
    # Materialize once so one-shot iterables can be scanned and then still
    # provide the fallback element.
    candidates = list(candidates)
    if not candidates:
        # An empty input has no fallback candidate to return.
        return ""

    if not get_any_candidate:
        for candidate in candidates:
            token_str = candidate["token_str"]
            if token_str[:2] != "##" and token_str.isalpha():
                return candidate["sequence"]

    # No acceptable token (or filtering disabled): take the top candidate.
    return candidates[0]["sequence"]
def infer_candidates(nlp, line):
    """Return the top ``TOPK`` predictions for the mask token in ``line``.

    Args:
        nlp: A PyTorch ``fill-mask`` pipeline; its ``tokenizer`` and
            ``model`` attributes are used directly.
        line: Text containing the tokenizer's mask token. If it contains
            several, only the first one is predicted.

    Returns:
        A list of dicts, best prediction first, each with the keys
        ``"sequence"`` (the line with the predicted token wrapped in a red
        ``<b>`` tag), ``"score"`` (softmax probability), ``"token"``
        (predicted token id), ``"token_str"`` (decoded token) and
        ``"masked_index"`` (position of the mask in the input ids).

    Raises:
        ValueError: If ``line`` contains no mask token.
    """
    # Normalize typographic punctuation to ASCII before tokenizing.
    # NOTE(review): the previous patterns were mis-encoded (three
    # identical-looking patterns in a row); they are restored here as
    # em dash, en dash, right single quote and ellipsis -- confirm.
    for fancy, plain in (
        ("\u2014", "-"),
        ("\u2013", "-"),
        ("\u2019", "'"),
        ("\u2026", "..."),
    ):
        line = line.replace(fancy, plain)

    # Go through the public tokenizer/model attributes rather than the
    # pipeline's private helpers, whose names and signatures are not stable.
    tokenizer = nlp.tokenizer
    inputs = tokenizer(line, return_tensors="pt")
    with torch.no_grad():
        logits = nlp.model(**inputs.to(nlp.model.device))[0].cpu()

    input_ids = inputs["input_ids"][0].cpu()
    mask_positions = torch.nonzero(
        input_ids == tokenizer.mask_token_id, as_tuple=False
    )
    if mask_positions.numel() == 0:
        raise ValueError("The line does not contain a mask token: %r" % line)
    masked_index = mask_positions[0].item()

    probs = logits[0, masked_index, :].softmax(dim=0)
    values, predictions = probs.topk(min(TOPK, probs.numel()))

    # Plain Python list: each candidate works on its own copy, so the
    # encoded input is never modified in place.
    base_ids = input_ids.tolist()
    pad_id = tokenizer.pad_token_id

    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        ids = list(base_ids)
        ids[masked_index] = token_id

        pieces = []
        for idx, current_id in enumerate(ids):
            # Skip padding without re-indexing, so `idx` keeps matching
            # `masked_index` wherever the padding sits.
            if current_id == pad_id:
                continue
            token = tokenizer.decode([current_id], skip_special_tokens=True)
            if token.startswith("##") and pieces:
                # Word-piece continuation: glue it to the previous piece.
                pieces[-1] += token[2:]
            elif idx == masked_index:
                pieces += ['<b style="color: #ff0000;">', token, "</b>"]
            else:
                pieces.append(token)

        result.append(
            {
                "sequence": " ".join(pieces).strip(),
                "score": score,
                "token": token_id,
                "token_str": tokenizer.decode([token_id]),
                "masked_index": masked_index,
            }
        )
    return result
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Mask one word per verse and let a fill-mask model fill it back in.

    Args:
        poem: Iterable of verse strings, one per line of the poem.
        ml_model: Model identifier handed to the ``fill-mask`` pipeline.
        masking: When true, each verse is masked with ``mask_line``; when
            false, the verses are expected to contain the mask already.
        language: Language code forwarded to ``mask_line``.

    Returns:
        A tuple ``(unmasked_poem, masked_lines)``: the rewritten verses
        joined with ``<br>``, and the list of masked verses that were fed to
        the model (blank verses are represented by ``""`` in both).
    """
    nlp = pipeline("fill-mask", model=ml_model)
    mask_token = nlp.tokenizer.mask_token

    unmasked_lines = []
    masked_lines = []
    for line in poem:
        # Treat whitespace-only verses like empty ones: there is nothing to
        # mask or predict in them.
        if not line.strip():
            unmasked_lines.append("")
            masked_lines.append("")
            continue

        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)

        if mask_token not in masked_line:
            # No mask to fill (e.g. masking disabled on a plain verse):
            # keep the verse as it is instead of failing in inference.
            unmasked_lines.append(masked_line)
            continue

        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))

    unmasked_poem = "<br>".join(unmasked_lines)
    return unmasked_poem, masked_lines
# Sidebar: title and one-line pitch.
# NOTE(review): the characters after "BERT" in the heading look like a
# mis-decoded emoji -- confirm against the upstream file.
instructions_text_0 = st.sidebar.markdown(
    """# ALBERTI vs BERT π₯
We present ALBERTI, our BERT-based multilingual model for poetry.""")
# Sidebar: short project description with a link to the model page.
instructions_text_1 = st.sidebar.markdown(
    """We have trained bert on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.
You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")
# Sidebar: drop-down listing the keys of SAMPLE_POEMS; the selected key is
# used below to pre-fill the input text area.
sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys())
)
# Sidebar: usage instructions and the list of training languages.
instructions_text_2 = st.sidebar.markdown("""# How to use
You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!
Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
poem, randomly masking one word per verse, and get the two new versions for each of the models.
The list of languages used on the training of ALBERTI are:
* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")
# Three page columns: col1 holds the input poem and the button, col2 the
# ALBERTI output and col3 the mBERT output (see the code below).
col1, col2, col3 = st.columns(3)
# Inject custom CSS: bold 1rem widget labels and 1rem horizontal padding on
# the main block container. unsafe_allow_html is required for the <style> tag.
st.markdown(
    """
<style>
label {
font-size: 1rem !important;
font-weight: bold !important;
}
.block-container {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
</style>
""", unsafe_allow_html=True)
# Main interaction: runs on every Streamlit rerun once a sample is selected.
if sample_chooser:
    # NOTE(review): model_list is not used anywhere in the visible code.
    model_list = set(MODELS.values())
    # Editable text area pre-filled with the selected sample poem.
    user_input = col1.text_area("Input poem",
                                "\n".join(SAMPLE_POEMS[sample_chooser]),
                                height=600)
    # One list entry per verse.
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    # Warn when the user already put mask tokens in the text; this only
    # displays a message and does not stop the rewrite below.
    if "[MASK]" in user_input or "<mask>" in user_input:
        col1.error("You don't have to mask the poem, we'll do it for you!")
    if rewrite_button:
        # Detect the poem's language; the first element of langid's result
        # is the language code.
        lang = langid.classify(user_input)[0]
        # First pass: mask each verse and fill it with the default model
        # (ALBERTI).
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        user_input_2 = col2.write(f"""<b>Output poem from ALBERTI</b>
{unmasked_poem}""", unsafe_allow_html=True)
        # Second pass: reuse the very same masked verses (masking=False) so
        # that mBERT fills exactly the positions ALBERTI was given.
        unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
                                          masking=False)
        user_input_3 = col3.write(f"""<b>Output poem from mBERT</b>
{unmasked_poem_2}""", unsafe_allow_html=True)