Commit a4dec41
Parent: 4b2d9b2
enable cuda, cpu takes forever
app.py CHANGED
@@ -40,16 +40,16 @@ client.add(collection_name="recipes",
 model_name = "LeoLM/leo-hessianai-13b-chat"
 
 last_messages = []
-
+@spaces.GPU
 def load_model():
     ankerbot_model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        device_map="
+        device_map="cuda:0",
         torch_dtype=torch.float16,
         use_cache=True,
         offload_folder="../offload"
     )
-
+    ankerbot_model.gradient_checkpointing_enable()
     ankerbot_tokenizer = AutoTokenizer.from_pretrained(model_name,
         torch_dtype=torch.float16,
         truncation=True,
@@ -60,7 +60,7 @@ def load_model():
 
 _model_cache = None
 
-
+@spaces.GPU
 def get_model():
     global _model_cache
     if _model_cache is None:
@@ -69,7 +69,7 @@ def get_model():
         _model_cache = load_model()
     return _model_cache
 
-
+@spaces.GPU
 def generate_response(query, context, prompts, max_tokens, temperature, top_p, generator):
     system_message_support = f"""<|im_start|>system
     Rolle: Du bist der KI-Assistent für Kundenservice, der im Namen des Unternehmens und Gewürzmanufaktur Ankerkraut handelt und Antworten aus der Ich-Perspektive, basierend auf den bereitgestellten Informationen gibt.
@@ -154,7 +154,7 @@ def search_qdrant_with_context(query_text, collection_name, top_k=3):
     print("Retrieved Text ", retrieved_texts)
 
     return retrieved_texts
-
+@spaces.GPU
 def respond(
     query,
     history: list[tuple[str, str]],
@@ -186,6 +186,7 @@ def respond(
     if len(last_messages) > 5:
         last_messages.pop(0)
     last_messages.append(full_conv)
+    print(last_messages)
     return answer
 
 """
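Taken together, the change puts every model-touching function behind Hugging Face's @spaces.GPU decorator and loads the weights in float16 directly onto the first GPU instead of the CPU. Below is a minimal, self-contained sketch of that pattern, assuming a ZeroGPU Space where the spaces package is available; the names (load_model, get_model, _model_cache, model_name) come from the diff above, and the remaining arguments mirror the committed ones.

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "LeoLM/leo-hessianai-13b-chat"
_model_cache = None  # holds the (model, tokenizer) pair after the first load


@spaces.GPU  # run this function on the Space's allocated GPU
def load_model():
    # Load the weights in half precision straight onto the first GPU,
    # spilling layers to the offload folder if VRAM runs short.
    ankerbot_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda:0",
        torch_dtype=torch.float16,
        use_cache=True,
        offload_folder="../offload",
    )
    ankerbot_model.gradient_checkpointing_enable()
    ankerbot_tokenizer = AutoTokenizer.from_pretrained(model_name, truncation=True)
    return ankerbot_model, ankerbot_tokenizer


@spaces.GPU
def get_model():
    # Lazy singleton: load once, then reuse the cached pair for every request.
    global _model_cache
    if _model_cache is None:
        _model_cache = load_model()
    return _model_cache

One caveat: gradient_checkpointing_enable() saves memory only during training backward passes, so for pure inference it is effectively a no-op, and transformers warns that use_cache is incompatible with gradient checkpointing during training; it is kept here only because the commit adds it.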