jayebaku commited on
Commit
c78b1ba
·
verified ·
1 Parent(s): 9317009

Update qa_summary.py

Browse files
Files changed (1) hide show
  1. qa_summary.py +7 -4
qa_summary.py CHANGED
@@ -2,20 +2,22 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
2
 
3
 
4
  def generate_answer(llm_name, texts, query, mode='validate'):
5
-
6
  if llm_name == 'solar':
7
  tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-Instruct-v1.0", use_fast=True)
8
  llm_model = AutoModelForCausalLM.from_pretrained(
9
  "Upstage/SOLAR-10.7B-Instruct-v1.0",
10
  device_map="auto", #device_map="cuda"
11
- #torch_dtype=torch.float16,)
 
12
 
13
  elif llm_name == 'mistral':
14
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
15
  llm_model = AutoModelForCausalLM.from_pretrained(
16
  "mistralai/Mistral-7B-Instruct-v0.2",
17
  device_map="auto", #device_map="cuda"
18
- #torch_dtype=torch.float16,)
 
19
 
20
  elif llm_name == 'phi3mini':
21
  tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
@@ -23,7 +25,8 @@ def generate_answer(llm_name, texts, query, mode='validate'):
23
  "microsoft/Phi-3-mini-128k-instruct",
24
  device_map="auto",
25
  torch_dtype="auto",
26
- trust_remote_code=True,)
 
27
 
28
  template_texts =""
29
  for i, text in enumerate(texts):
 
2
 
3
 
4
  def generate_answer(llm_name, texts, query, mode='validate'):
5
+
6
  if llm_name == 'solar':
7
  tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-Instruct-v1.0", use_fast=True)
8
  llm_model = AutoModelForCausalLM.from_pretrained(
9
  "Upstage/SOLAR-10.7B-Instruct-v1.0",
10
  device_map="auto", #device_map="cuda"
11
+ #torch_dtype=torch.float16,
12
+ )
13
 
14
  elif llm_name == 'mistral':
15
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
16
  llm_model = AutoModelForCausalLM.from_pretrained(
17
  "mistralai/Mistral-7B-Instruct-v0.2",
18
  device_map="auto", #device_map="cuda"
19
+ #torch_dtype=torch.float16,
20
+ )
21
 
22
  elif llm_name == 'phi3mini':
23
  tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
 
25
  "microsoft/Phi-3-mini-128k-instruct",
26
  device_map="auto",
27
  torch_dtype="auto",
28
+ trust_remote_code=True,
29
+ )
30
 
31
  template_texts =""
32
  for i, text in enumerate(texts):