Update app.py
Browse files
app.py
CHANGED
|
@@ -104,7 +104,7 @@ class Translators:
|
|
| 104 |
return self.HelsinkiNLP_mulroa()
|
| 105 |
except KeyError as error:
|
| 106 |
return f"Error: Translation direction {self.sl} to {self.tl} is not supported by Helsinki Translation Models", error
|
| 107 |
-
|
| 108 |
def LLaMAX(self):
|
| 109 |
pipe = pipeline("text-generation", model="LLaMAX/LLaMAX3-8B")
|
| 110 |
messages = [
|
|
@@ -163,8 +163,8 @@ class Translators:
|
|
| 163 |
model = T5ForConditionalGeneration.from_pretrained(self.model_name, device_map="auto")
|
| 164 |
prompt = f"translate {self.sl} to {self.tl}: {self.input_text}"
|
| 165 |
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
| 166 |
-
output_ids = model.generate(input_ids, max_length=512)
|
| 167 |
-
translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
| 168 |
return translated_text
|
| 169 |
|
| 170 |
def mbart_many_to_many(self):
|
|
@@ -519,70 +519,24 @@ st.session_state["sselected_language"] = sselected_language
|
|
| 519 |
st.session_state["tselected_language"] = tselected_language
|
| 520 |
st.session_state["model_name"] = model_name
|
| 521 |
|
| 522 |
-
|
| 523 |
-
# st.write(magic)
|
| 524 |
-
f'Selected language combination: {sselected_language} - {tselected_language}. Selected model: {model_name}'
|
| 525 |
|
| 526 |
with st.container(border=None, width="stretch", height="content", horizontal=False, horizontal_alignment="center", vertical_alignment="center", gap="small"):
|
| 527 |
submit_button = st.button("Translate")
|
| 528 |
-
# Show text area with placeholder
|
| 529 |
# translated_textarea = st.empty()
|
| 530 |
# message_textarea = st.empty()
|
| 531 |
# translated_textarea.text_area(":green[Translation:]", placeholder="Translation area", value='')
|
| 532 |
# message_textarea.text_input(":blue[Messages:]", placeholder="Messages area", value='')
|
| 533 |
|
| 534 |
-
# Handle the submit button click
|
| 535 |
-
if submit_button:
|
| 536 |
with st.spinner("Translating...", show_time=True):
|
| 537 |
-
translated_text, message = translate_text(model_name, sselected_language, tselected_language, input_text)
|
| 538 |
-
# if model_name.startswith('Helsinki-NLP'):
|
| 539 |
-
# # input_ids = tokenizer.encode(input_text, return_tensors='pt')
|
| 540 |
-
# # # Perform translation
|
| 541 |
-
# # output_ids = model.generate(input_ids)
|
| 542 |
-
# # # Decode the translated text
|
| 543 |
-
# # translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 544 |
-
# # Use a pipeline as a high-level helper
|
| 545 |
-
# try:
|
| 546 |
-
# model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
|
| 547 |
-
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 548 |
-
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 549 |
-
# pipe = pipeline("translation", model=model, tokenizer=tokenizer)
|
| 550 |
-
# except (EnvironmentError, OSError):
|
| 551 |
-
# model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
|
| 552 |
-
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 553 |
-
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 554 |
-
# pipe = pipeline("translation", model=model, tokenizer=tokenizer)
|
| 555 |
-
# translation = pipe(input_text)
|
| 556 |
-
# translated_text = translation[0]['translation_text']
|
| 557 |
-
|
| 558 |
-
# elif model_name.startswith('t5'):
|
| 559 |
-
# tokenizer = T5Tokenizer.from_pretrained(model_name)
|
| 560 |
-
# model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
|
| 561 |
-
# prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
|
| 562 |
-
# input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
|
| 563 |
-
# # Perform translation
|
| 564 |
-
# output_ids = model.generate(input_ids)
|
| 565 |
-
# # Decode the translated text
|
| 566 |
-
# translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 567 |
-
# elif 'Unbabel' in model_name:
|
| 568 |
-
# pipe = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")
|
| 569 |
-
# # We use the tokenizer’s chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
|
| 570 |
-
# messages = [{"role": "user",
|
| 571 |
-
# "content": f"Translate the following text from {sselected_language} into {tselected_language}.\n{sselected_language}: {input_text}.\n{tselected_language}:"}]
|
| 572 |
-
# prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
| 573 |
-
# outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
|
| 574 |
-
# translated_text = outputs[0]["generated_text"]
|
| 575 |
-
# start_marker = "<end_of_turn>"
|
| 576 |
-
# if start_marker in translated_text:
|
| 577 |
-
# translated_text = translated_text.split(start_marker)[1].strip()
|
| 578 |
-
# translated_text = translated_text.replace('Answer:', '').strip() if translated_text.startswith('Answer:') else translated_text
|
| 579 |
-
|
| 580 |
-
# Display the translated text
|
| 581 |
print(f"Translated from {sselected_language} to {tselected_language} using {model_name}.", input_text, translated_text)
|
| 582 |
-
#
|
| 583 |
# translated_textarea.text_area(":green[Translation:]", value=translated_text)
|
| 584 |
# message_textarea.text_input(":blue[Message:]", value=message)
|
| 585 |
-
st.text_area(":green[Translation:]",
|
| 586 |
# st.success(message, icon=":material/check:") st.info(message, icon="ℹ️"), st.warning(message, icon=":material/warning:"), error(message, icon=":material/error:"), st.exception
|
| 587 |
st.info(message, icon=":material/info:")
|
| 588 |
# st.text_input(":blue[Messages:]", value=message)
|
|
|
|
| 104 |
return self.HelsinkiNLP_mulroa()
|
| 105 |
except KeyError as error:
|
| 106 |
return f"Error: Translation direction {self.sl} to {self.tl} is not supported by Helsinki Translation Models", error
|
| 107 |
+
|
| 108 |
def LLaMAX(self):
|
| 109 |
pipe = pipeline("text-generation", model="LLaMAX/LLaMAX3-8B")
|
| 110 |
messages = [
|
|
|
|
| 163 |
model = T5ForConditionalGeneration.from_pretrained(self.model_name, device_map="auto")
|
| 164 |
prompt = f"translate {self.sl} to {self.tl}: {self.input_text}"
|
| 165 |
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
| 166 |
+
output_ids = model.generate(input_ids, max_length=512) # Perform translation
|
| 167 |
+
translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() # Decode the translated text
|
| 168 |
return translated_text
|
| 169 |
|
| 170 |
def mbart_many_to_many(self):
|
|
|
|
| 519 |
st.session_state["tselected_language"] = tselected_language
|
| 520 |
st.session_state["model_name"] = model_name
|
| 521 |
|
| 522 |
+
st.write(f'Selected language combination: {sselected_language} - {tselected_language}. Selected model: {model_name}')
|
|
|
|
|
|
|
| 523 |
|
| 524 |
with st.container(border=None, width="stretch", height="content", horizontal=False, horizontal_alignment="center", vertical_alignment="center", gap="small"):
|
| 525 |
submit_button = st.button("Translate")
|
| 526 |
+
# Show text area with placeholder also before translating
|
| 527 |
# translated_textarea = st.empty()
|
| 528 |
# message_textarea = st.empty()
|
| 529 |
# translated_textarea.text_area(":green[Translation:]", placeholder="Translation area", value='')
|
| 530 |
# message_textarea.text_input(":blue[Messages:]", placeholder="Messages area", value='')
|
| 531 |
|
| 532 |
+
if submit_button: # Handle the submit button click
|
|
|
|
| 533 |
with st.spinner("Translating...", show_time=True):
|
| 534 |
+
translated_text, message = translate_text(model_name, sselected_language, tselected_language, input_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
print(f"Translated from {sselected_language} to {tselected_language} using {model_name}.", input_text, translated_text)
|
| 536 |
+
# Display the translated text
|
| 537 |
# translated_textarea.text_area(":green[Translation:]", value=translated_text)
|
| 538 |
# message_textarea.text_input(":blue[Message:]", value=message)
|
| 539 |
+
st.text_area(":green[Translation:]", value=translated_text)
|
| 540 |
# st.success(message, icon=":material/check:") st.info(message, icon="ℹ️"), st.warning(message, icon=":material/warning:"), error(message, icon=":material/error:"), st.exception
|
| 541 |
st.info(message, icon=":material/info:")
|
| 542 |
# st.text_input(":blue[Messages:]", value=message)
|