TiberiuCristianLeon commited on
Commit
153cb2b
·
verified ·
1 Parent(s): 4e055aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -70,17 +70,17 @@ class Translators:
70
  model.half() # recommended
71
  model.eval()
72
  # Translating from one or several sentences to a sole language
73
- src_tokens = tokenizer.encode_source_tokens_to_input_ids([self.input_text, ], target_language=self.tl)
74
  # src_tokens = src_tokens.to(self.device)
75
- generated_tokens = model.generate(src_tokens)
76
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
77
  # Translating from one or several sentences to corresponding languages
78
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
79
  # generated_tokens = model.generate(src_tokens.to(self.device))
80
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
81
- # with torch.no_grad():
82
- # generated_tokens = model.generate(src_tokens)
83
- # return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
84
 
85
  def hplt(self, opus = False):
86
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']
 
70
  model.half() # recommended
71
  model.eval()
72
  # Translating from one or several sentences to a sole language
73
+ src_tokens = tokenizer.encode_source_tokens_to_input_ids(self.input_text, target_language=self.tl)
74
  # src_tokens = src_tokens.to(self.device)
75
+ # generated_tokens = model.generate(src_tokens)
76
+ # return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
77
  # Translating from one or several sentences to corresponding languages
78
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
79
  # generated_tokens = model.generate(src_tokens.to(self.device))
80
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
81
+ with torch.no_grad():
82
+ generated_tokens = model.generate(src_tokens)
83
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
84
 
85
  def hplt(self, opus = False):
86
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']