TiberiuCristianLeon commited on
Commit
2ad2ba0
·
verified ·
1 Parent(s): 153cb2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -65,9 +65,9 @@ class Translators:
65
  def mitre(self):
66
  from transformers import AutoModel, AutoTokenizer
67
  tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, use_fast=False)
68
- # model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
69
  model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
70
- model.half() # recommended
71
  model.eval()
72
  # Translating from one or several sentences to a sole language
73
  src_tokens = tokenizer.encode_source_tokens_to_input_ids(self.input_text, target_language=self.tl)
@@ -80,7 +80,7 @@ class Translators:
80
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
81
  with torch.no_grad():
82
  generated_tokens = model.generate(src_tokens)
83
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
84
 
85
  def hplt(self, opus = False):
86
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']
 
65
  def mitre(self):
66
  from transformers import AutoModel, AutoTokenizer
67
  tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, use_fast=False)
68
+ # model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True, use_fast=False).to(self.device)
69
  model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
70
+ # model.half() # recommended for GPU
71
  model.eval()
72
  # Translating from one or several sentences to a sole language
73
  src_tokens = tokenizer.encode_source_tokens_to_input_ids(self.input_text, target_language=self.tl)
 
80
  # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
81
  with torch.no_grad():
82
  generated_tokens = model.generate(src_tokens)
83
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
84
 
85
  def hplt(self, opus = False):
86
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']