Spaces:
Running
Running
import logging
import os

import torch
from transformers import MarianMTModel, AutoTokenizer
import ctranslate2

from colorize import align_words
# Configure the root logger (getLogger() with no name) so records that
# propagate from any module are also written to app.log.
logger = logging.getLogger()
logger.setLevel(logging.INFO)  # INFO and above; switch to logging.DEBUG for more detail

_log_path = os.path.abspath('app.log')
# Reuse an app.log handler already attached to the root logger (e.g. when this
# module is executed again on a hot reload) so records are not written twice.
file_handler = next(
    (h for h in logger.handlers
     if isinstance(h, logging.FileHandler) and h.baseFilename == _log_path),
    None,
)
if file_handler is None:
    # 'a' appends to the file across restarts. The logged text includes
    # Arabic/Hebrew, so force UTF-8 instead of the platform default encoding.
    file_handler = logging.FileHandler(_log_path, mode='a', encoding='utf-8')
    logger.addHandler(file_handler)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# Hugging Face Hub repositories, one per translation direction.
_EN_TO_AR_REPO = "guymorlan/levanti_translate_en_ar"
_AR_TO_EN_REPO = "guymorlan/levanti_translate_ar_en"

# PyTorch models, configured to return attention weights; translate() feeds
# their output to align_words for word colorization.
model_to_ar = MarianMTModel.from_pretrained(_EN_TO_AR_REPO, output_attentions=True)
model_from_ar = MarianMTModel.from_pretrained(_AR_TO_EN_REPO, output_attentions=True)

# CTranslate2 translators used by translate() to generate the actual output
# (presumably conversions of the same models, judging by the paths -- confirm).
model_to_ar_ct2 = ctranslate2.Translator("./en_ar_ct2/")
model_from_ar_ct2 = ctranslate2.Translator("./ar_en_ct2/")

# One tokenizer per direction; translate() uses it for both model backends.
tokenizer_to_ar = AutoTokenizer.from_pretrained(_EN_TO_AR_REPO)
tokenizer_from_ar = AutoTokenizer.from_pretrained(_AR_TO_EN_REPO)

print("Done loading models")
# Maps a dialect name (English or Hebrew -- presumably the labels offered by the
# UI; confirm against the caller) to the one-letter code that run_translate()
# prefixes to English input before translating into Arabic.
dialect_map = {
    "Palestinian": "P",
    "Syrian": "S",
    "Lebanese": "L",
    "Egyptian": "E",
    # Hebrew names of the same dialects.
    "פלסטיני": "P",  # Palestinian
    "סורי": "S",     # Syrian
    "לבנוני": "L",   # Lebanese
    "מצרי": "E",     # Egyptian
}
def _default_threshold(n_tokens):
    """Return the alignment threshold ``translate`` uses when none is given.

    Args:
        n_tokens: Number of tokens in the encoded input sentence.
    """
    # NOTE(review): 10-19 tokens get 0.10 while both shorter and longer inputs
    # get 0.05 -- confirm this non-monotonic choice is intentional.
    if 10 <= n_tokens < 20:
        return 0.10
    return 0.05


def translate(text, ct_model, hf_model, tokenizer, to_arabic=True,
              threshold=None, layer=2, head=6):
    """Translate *text* and build an attention-colorized HTML view of the pair.

    The CTranslate2 model produces the translation. The Hugging Face model is
    then run once on the (source, translation) pair only to obtain attention
    weights, which ``align_words`` turns into colorized source/target HTML.

    Args:
        text: Sentence to translate (English input may carry a dialect-code prefix).
        ct_model: CTranslate2 translator for this direction.
        hf_model: Matching Hugging Face model that returns attentions.
        tokenizer: Tokenizer shared by both models.
        to_arabic: True when translating into Arabic; forwarded to
            ``align_words`` as ``skip_first_src`` (presumably to skip the
            dialect-code token -- confirm in ``colorize``).
        threshold: Alignment threshold. ``None`` picks a default based on the
            input length; any other value, including 0, is used as given.
        layer: Attention layer passed to ``align_words``.
        head: Attention head passed to ``align_words``.

    Returns:
        ``(html, arabic)``: ``html`` is an RTL ``<div>`` with the colorized
        source and target; ``arabic`` is the translation if it is mostly
        Arabic, otherwise the input ``text``.
    """
    logger.info("Translating: %s", text)

    # Translate with CTranslate2, which works on token strings.
    inp_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    out_tokens = ct_model.translate_batch([inp_tokens])[0].hypotheses[0]
    out_string = tokenizer.convert_tokens_to_string(out_tokens)

    # Teacher-force the CTranslate2 output through the Hugging Face model to
    # get attention weights. "<pad>" is prepended as the decoder start token
    # (the Marian convention) and "</s>" closes the target sequence.
    encoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(inp_tokens)).unsqueeze(0)
    decoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(["<pad>"] + out_tokens + ['</s>'])).unsqueeze(0)
    # Inference only: building an autograd graph would just waste memory.
    with torch.no_grad():
        colorization_output = hf_model(input_ids=encoder_input_ids,
                                       decoder_input_ids=decoder_input_ids)

    # `is None` (not falsiness) so an explicit threshold of 0 is honoured.
    if threshold is None:
        threshold = _default_threshold(len(inp_tokens))

    srchtml, tgthtml = align_words(colorization_output,
                                   tokenizer,
                                   encoder_input_ids,
                                   decoder_input_ids,
                                   threshold,
                                   skip_first_src=to_arabic,
                                   skip_second_src=False,
                                   layer=layer,
                                   head=head)
    html = f"<div style='direction: rtl'>{srchtml}<br><br>{tgthtml}</div>"
    # Report whichever side of the pair is Arabic.
    arabic = out_string if is_arabic(out_string) else text
    return html, arabic
| #%% | |
def is_arabic(text):
    """Return True if more than half of the non-whitespace characters are Arabic.

    A character counts as Arabic when its code point is in the basic Arabic
    block U+0600-U+06FF. All whitespace (spaces, tabs, newlines, ...) is
    ignored. Text with no non-whitespace characters, including the empty
    string, is not considered Arabic.
    """
    chars = "".join(text.split())
    if not chars:
        # Nothing to measure; avoids dividing by zero below.
        return False
    arabic_chars = sum(1 for c in chars if "\u0600" <= c <= "\u06FF")
    return arabic_chars / len(chars) > 0.5
def run_translate(text, dialect=None):
    """Translate *text* in the direction implied by its script.

    Mostly-Arabic input (per ``is_arabic``) is translated Arabic -> English.
    Anything else is translated English -> Arabic, optionally prefixed with a
    dialect code.

    Args:
        text: Input sentence.
        dialect: A dialect name from ``dialect_map`` (mapped to its code) or
            any other value, which is used as the prefix unchanged. A falsy
            value means no prefix.

    Returns:
        ``(html, arabic)`` as produced by ``translate``, or ``("", "")`` for
        empty / whitespace-only input, so callers can always unpack two values.
    """
    # Whitespace-only text has nothing to translate or classify.
    if not text or text.isspace():
        return "", ""
    if is_arabic(text):
        return translate(text, model_from_ar_ct2, model_from_ar, tokenizer_from_ar,
                         to_arabic=False, threshold=None, layer=2, head=7)
    # Known dialect names become their one-letter code; other values pass through.
    dialect = dialect_map.get(dialect, dialect)
    if dialect:
        text = f"{dialect} {text}"
    return translate(text, model_to_ar_ct2, model_to_ar, tokenizer_to_ar,
                     to_arabic=True, threshold=None, layer=2, head=7)