# Hugging Face Spaces app (status banner from the hosting page removed).
| import numpy as np | |
| import pandas as pd | |
| import re | |
| import os | |
| import cloudpickle | |
| from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM | |
| import tensorflow as tf | |
| import spacy | |
| import streamlit as st | |
| import logging | |
| import traceback | |
| from scraper import scrape_text | |
# Training data https://www.kaggle.com/datasets/vladimirvorobevv/chatgpt-paraphrases
# Force legacy Keras 2 behaviour in TF — presumably required by the TF
# transformers model / saved .h5 weights below; TODO confirm against the
# transformers version pinned for this app.
os.environ['TF_USE_LEGACY_KERAS'] = "1"
CHECKPOINT = "facebook/bart-base"  # base checkpoint; fine-tuned weights are loaded on top
INPUT_N_TOKENS = 70   # max source tokens fed to the tokenizer (truncation/padding length)
TARGET_N_TOKENS = 70  # max new tokens generated per paraphrase
@st.cache_resource
def load_models():
    """Load the spaCy sentence splitter, BART tokenizer, and paraphraser model.

    Decorated with ``st.cache_resource`` so the heavy models are loaded once
    per server process instead of on every Streamlit script rerun (Streamlit
    re-executes the whole module on each user interaction).

    Returns:
        tuple: (spacy Language, BartTokenizerFast, TF seq2seq model).
    """
    nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
    tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    # Overlay the fine-tuned paraphraser weights onto the base checkpoint.
    model.load_weights(os.path.join("models", "bart_en_paraphraser.h5"), by_name=True)
    logging.warning('Loaded models')
    return nlp, tokenizer, model


nlp, tokenizer, model = load_models()
def inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of sentences for the paraphraser.

    Pads/truncates every sentence to ``n_tokens`` and returns TF tensors.

    Returns:
        tuple: (the module-level tokenizer, the tokenized batch).
    """
    batch = tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return tokenizer, batch
def process_result(result: str) -> str:
    """Clean one decoded model output into a single plain sentence.

    The training targets look like JSON-style lists of paraphrases, so the
    decoded text may contain several quoted sentences joined by '", "'.
    Keep only the first one, then strip residual markup.

    Args:
        result: raw text decoded from the model.

    Returns:
        The cleaned first paraphrase.
    """
    # str.split always returns a list, so taking element 0 directly replaces
    # the original dead `isinstance(result, str)` branch (it was never True).
    result = result.split('", "')[0]
    # Drop any leftover angle-bracket tokens (e.g. special tokens like <s>).
    result = re.sub(r'(<.*?>)', "", result).strip()
    # Drop a leading double-quote left over from the list formatting.
    result = re.sub(r'^"', "", result).strip()
    return result
def inference(txt):
    """Paraphrase ``txt`` (a list of sentences) with the BART model.

    Args:
        txt: list of source sentences.

    Returns:
        list[str]: one cleaned paraphrase per input sentence.

    Raises:
        Exception: re-raises anything that fails during tokenization,
        generation, or decoding (after logging the traceback).
    """
    try:
        inference_tokenizer, tokenized_data = inference_tokenize(input_=txt, n_tokens=INPUT_N_TOKENS)
        pred = model.generate(**tokenized_data, max_new_tokens=TARGET_N_TOKENS, num_beams=4)
        result = [process_result(inference_tokenizer.decode(p, skip_special_tokens=True)) for p in pred]
        logging.warning('paraphrased_result: %s', result)  # lazy args instead of f-string
        return result
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # propagate untouched; still logged and re-raised for the caller.
        logging.warning(traceback.format_exc())
        raise
def inference_long_text(txt, n_sents):
    """Split ``txt`` into sentences and paraphrase the first ``n_sents`` usable ones.

    Sentences with fewer than three words are skipped. Returns the selected
    source sentences alongside their paraphrases.
    """
    selected = []
    for sent in nlp(txt).sents:
        if len(selected) >= n_sents:
            break
        # Ignore fragments too short to paraphrase meaningfully.
        if len(sent.text.split()) >= 3:
            selected.append(sent.text)
    with st.spinner('Rewriting...'):
        rewritten = inference(selected)
    return selected, rewritten
| ############## ENTRY POINT START ####################### | |
def _rewrite(raw_txt, n_sents):
    """Normalise whitespace and paraphrase up to ``n_sents`` sentences.

    Emits Streamlit info/success/error messages as it goes.

    Returns:
        tuple: (source_sentences, paraphrased_sentences, error_message).
        ``error_message`` is None on success; the sentence lists are None
        on failure.
    """
    # Collapse newlines so the sentence splitter sees continuous prose.
    raw_txt = re.sub(r'\n+', ' ', raw_txt)
    try:
        st.info("Rewriting the text. This takes time.", icon="ℹ️")
        sources, rewritten = inference_long_text(raw_txt, n_sents)
    except Exception as e:
        st.error("Encountered an error while rewriting the text.", icon="🚨")
        return None, None, str(e)
    st.success("Successfully rewrote the text.", icon="✅")
    return sources, rewritten, None


def _render_pairs(sources, rewritten):
    """Render scraped/rewritten sentence pairs as HTML via st.markdown."""
    pairs = [
        f"<b>Scraped Sentence:</b> {src}<br><b>Rewritten Sentence:</b> {out}"
        for src, out in zip(sources, rewritten)
    ]
    html = "<br><br>".join(pairs)
    # Escape dollar signs so st.markdown does not treat them as LaTeX math.
    # (The original `replace("$", "$")` was a no-op — a garbled escape.)
    html = html.replace("$", "\\$")
    st.markdown(html, unsafe_allow_html=True)


def main():
    """Streamlit entry point: take a URL or pasted text and rewrite it."""
    st.markdown('''<h3>Text Rewriter</h3>''', unsafe_allow_html=True)
    input_type = st.radio('Select an option:',
                          ['Paste URL of the text', 'Paste the text'],
                          horizontal=True)
    n_sents = st.slider('Select the number of sentences to process', 5, 30, 10)

    if input_type == 'Paste URL of the text':
        input_url = st.text_input("Paste URL of the text", "")
        if st.button("Submit") or input_url:
            scrape_error = None
            paraphrase_error = None
            input_txt = None
            paraphrased_txt = None
            with st.status("Processing...", expanded=True) as status:
                status.empty()
                # --- Scraping ---
                try:
                    st.info("Scraping data from the URL.", icon="ℹ️")
                    input_txt = scrape_text(input_url)
                    st.success("Successfully scraped the data.", icon="✅")
                except Exception as e:
                    input_txt = None
                    scrape_error = str(e)
                    st.error("Encountered an error while scraping the data.", icon="🚨")
                # --- Paraphrasing (only if scraping succeeded) ---
                if input_txt is not None:
                    input_txt, paraphrased_txt, paraphrase_error = _rewrite(input_txt, n_sents)
                if scrape_error is None and paraphrase_error is None:
                    status.update(label="Done", state="complete", expanded=False)
                else:
                    status.update(label="Error", state="error", expanded=False)
            # Render the outcome outside the (now collapsed) status widget.
            if scrape_error is not None:
                st.error(f"Scrape Error: \n{scrape_error}", icon="🚨")
            elif paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_pairs(input_txt, paraphrased_txt)
    else:
        input_txt = st.text_area(
            "Enter the text. (Ensure the text is grammatically correct and has punctuations at the right places):",
            "", height=150)
        if st.button("Submit") or input_txt:
            paraphrase_error = None
            paraphrased_txt = None
            with st.status("Processing...", expanded=True) as status:
                input_txt, paraphrased_txt, paraphrase_error = _rewrite(input_txt, n_sents)
                if paraphrase_error is None:
                    status.update(label="Done", state="complete", expanded=False)
                else:
                    status.update(label="Error", state="error", expanded=False)
            if paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_pairs(input_txt, paraphrased_txt)
| ############## ENTRY POINT END ####################### | |
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()