Update app.py
Browse files
app.py
CHANGED
|
@@ -371,10 +371,11 @@ class Translators:
|
|
| 371 |
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 372 |
model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
| 373 |
self.input_text = self.input_text if self.input_text.endswith('.') else f'{self.input_text}.'
|
| 374 |
-
inputs = tokenizer.encode(f"Translate to {self.tl}: {self.input_text}", return_tensors="pt")
|
| 375 |
outputs = model.generate(inputs)
|
| 376 |
translation = tokenizer.decode(outputs[0])
|
| 377 |
translation = translation.replace('<pad> ', '').replace('</s>', '')
|
|
|
|
| 378 |
return translation
|
| 379 |
|
| 380 |
def bloomz(self):
|
|
|
|
| 371 |
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 372 |
model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
| 373 |
self.input_text = self.input_text if self.input_text.endswith('.') else f'{self.input_text}.'
|
| 374 |
+
inputs = tokenizer.encode(f"Translate to {self.tl}: {self.input_text} Translation:", return_tensors="pt")
|
| 375 |
outputs = model.generate(inputs)
|
| 376 |
translation = tokenizer.decode(outputs[0])
|
| 377 |
translation = translation.replace('<pad> ', '').replace('</s>', '')
|
| 378 |
+
translation = translation.split('Translation:')[-1].strip() if 'Translation:' in translation else translation.strip()
|
| 379 |
return translation
|
| 380 |
|
| 381 |
def bloomz(self):
|