Ankerkraut commited on
Commit
f016689
·
1 Parent(s): c740076

move to one device

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -45,7 +45,7 @@ last_messages = []
45
  def load_model():
46
  ankerbot_model = AutoModelForCausalLM.from_pretrained(
47
  model_name,
48
- device_map="cpu",
49
  torch_dtype=torch.float16,
50
  use_cache=True,
51
  offload_folder="../offload"
@@ -55,9 +55,8 @@ def load_model():
55
  torch_dtype=torch.float16,
56
  truncation=True,
57
  padding=True, )
58
- ankerbot_model.to("cuda")
59
- generator = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False) # True for flash-attn2 else False
60
- generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False) # True for flash-attn2 else False
61
  return (generator, generator_mini)
62
 
63
  _model_cache = None
@@ -81,9 +80,9 @@ def generate_response(query, context, prompts, max_tokens, temperature, top_p, g
81
  Du bekommst Kundenanfragen zum Beispiel zu einer Bestellung, antworte Anhand des zur Verfügunggestellten Kontextes.
82
  Tu so, als wär der Kontext Bestandteil deines Wissens. Sprich den Kunden persönlich an.
83
  Nenne nichts außerhalb des Kontext.
84
- Konversation: {",".join(last_messages)}
85
  Kontext Kundenservice: {context}
86
  <|im_end|>
 
87
  <|im_start|>user
88
  Frage: {query}
89
  <|im_end|>
@@ -97,9 +96,9 @@ def generate_response(query, context, prompts, max_tokens, temperature, top_p, g
97
  Du bekommst im Kontext Informationen zu Rezepten und Gerichten.
98
  Tu so, als wär der Kontext Bestandteil deines Wissens. Sprich den Kunden persönlich an.
99
  Nenne nichts außerhalb des Kontext.
100
- Konversation: {",".join(last_messages)}
101
  Kontext Rezepte: {context}
102
  <|im_end|>
 
103
  <|im_start|>user
104
  Frage: {query}
105
  <|im_end|>
@@ -113,9 +112,9 @@ def generate_response(query, context, prompts, max_tokens, temperature, top_p, g
113
  Du bekommst im Kontext Informationen zu Produkte, nach denen gefragt ist, oder welche ähnlich sein könnten.
114
  Tu so, als wär der Kontext Bestandteil deines Wissens. Sprich den Kunden persönlich an.
115
  Nenne nichts außerhalb des Kontext.
116
- Konversation: {",".join(last_messages)}
117
  Kontext Produkte: {context}
118
  <|im_end|>
 
119
  <|im_start|>user
120
  Frage: {query}
121
  <|im_end|>
@@ -173,6 +172,7 @@ def respond(
173
  Frage: {query}
174
  <|im_end|>
175
  <|im_start|>assistant"""
 
176
  refined_context = generator[1](system_message, do_sample=True, padding=True, truncation=True, top_p=0.95, max_new_tokens=150)
177
  # Retrieve relevant context from Qdrant
178
  if "rezept" in query.lower() or "gericht" in query.lower():
@@ -184,7 +184,7 @@ def respond(
184
 
185
  context = search_qdrant_with_context(query + " " + refined_context[0]["generated_text"].split("assistant\n").pop(), collection_name)
186
  answer = generate_response(query, context, last_messages, max_tokens, temperature, top_p, generator[0])
187
- full_conv = f"Nutzer:{query};Assistent:{answer}"
188
  if len(last_messages) > 5:
189
  last_messages.pop(0)
190
  last_messages.append(full_conv)
 
45
  def load_model():
46
  ankerbot_model = AutoModelForCausalLM.from_pretrained(
47
  model_name,
48
+ device_map="auto",
49
  torch_dtype=torch.float16,
50
  use_cache=True,
51
  offload_folder="../offload"
 
55
  torch_dtype=torch.float16,
56
  truncation=True,
57
  padding=True, )
58
+ generator = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False, device="cuda:0") # True for flash-attn2 else False
59
+ generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, trust_remote_code=False, device="cuda:0") # True for flash-attn2 else False
 
60
  return (generator, generator_mini)
61
 
62
  _model_cache = None
 
80
  Du bekommst Kundenanfragen zum Beispiel zu einer Bestellung, antworte Anhand des zur Verfügunggestellten Kontextes.
81
  Tu so, als wär der Kontext Bestandteil deines Wissens. Sprich den Kunden persönlich an.
82
  Nenne nichts außerhalb des Kontext.
 
83
  Kontext Kundenservice: {context}
84
  <|im_end|>
85
+ {"".join(last_messages)}
86
  <|im_start|>user
87
  Frage: {query}
88
  <|im_end|>
 
96
  Du bekommst im Kontext Informationen zu Rezepten und Gerichten.
97
  Tu so, als wär der Kontext Bestandteil deines Wissens. Sprich den Kunden persönlich an.
98
  Nenne nichts außerhalb des Kontext.
 
99
  Kontext Rezepte: {context}
100
  <|im_end|>
101
+ {"".join(last_messages)}
102
  <|im_start|>user
103
  Frage: {query}
104
  <|im_end|>
 
112
  Du bekommst im Kontext Informationen zu Produkte, nach denen gefragt ist, oder welche ähnlich sein könnten.
113
  Tu so, als wär der Kontext Bestandteil deines Wissens. Sprich den Kunden persönlich an.
114
  Nenne nichts außerhalb des Kontext.
 
115
  Kontext Produkte: {context}
116
  <|im_end|>
117
+ {"".join(last_messages)}
118
  <|im_start|>user
119
  Frage: {query}
120
  <|im_end|>
 
172
  Frage: {query}
173
  <|im_end|>
174
  <|im_start|>assistant"""
175
+ system_message = system_message.to("cuda") if torch.cuda.is_available() else system_message
176
  refined_context = generator[1](system_message, do_sample=True, padding=True, truncation=True, top_p=0.95, max_new_tokens=150)
177
  # Retrieve relevant context from Qdrant
178
  if "rezept" in query.lower() or "gericht" in query.lower():
 
184
 
185
  context = search_qdrant_with_context(query + " " + refined_context[0]["generated_text"].split("assistant\n").pop(), collection_name)
186
  answer = generate_response(query, context, last_messages, max_tokens, temperature, top_p, generator[0])
187
+ full_conv = f"<|im_start|>user {query}<|im_end|><|im_start|>assistent {answer}<|im_end|>"
188
  if len(last_messages) > 5:
189
  last_messages.pop(0)
190
  last_messages.append(full_conv)