TiberiuCristianLeon commited on
Commit
8f9c8c9
·
verified ·
1 Parent(s): e205b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -371,10 +371,11 @@ class Translators:
371
  tokenizer = AutoTokenizer.from_pretrained(self.model_name)
372
  model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
373
  self.input_text = self.input_text if self.input_text.endswith('.') else f'{self.input_text}.'
374
- inputs = tokenizer.encode(f"Translate to {self.tl}: {self.input_text}", return_tensors="pt")
375
  outputs = model.generate(inputs)
376
  translation = tokenizer.decode(outputs[0])
377
  translation = translation.replace('<pad> ', '').replace('</s>', '')
 
378
  return translation
379
 
380
  def bloomz(self):
 
371
  tokenizer = AutoTokenizer.from_pretrained(self.model_name)
372
  model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
373
  self.input_text = self.input_text if self.input_text.endswith('.') else f'{self.input_text}.'
374
+ inputs = tokenizer.encode(f"Translate to {self.tl}: {self.input_text} Translation:", return_tensors="pt")
375
  outputs = model.generate(inputs)
376
  translation = tokenizer.decode(outputs[0])
377
  translation = translation.replace('<pad> ', '').replace('</s>', '')
378
+ translation = translation.split('Translation:')[-1].strip() if 'Translation:' in translation else translation.strip()
379
  return translation
380
 
381
  def bloomz(self):