duzx16 committed
Commit 8eb45c8 · 1 Parent(s): 487aa2f

use inference_mode
Files changed:
- modeling_chatglm.py (+3, -3)
modeling_chatglm.py
CHANGED
@@ -1014,7 +1014,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         return inputs

-    @torch.no_grad()
+    @torch.inference_mode()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
         if history is None:
@@ -1032,7 +1032,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         history = history + [(query, response)]
         return response, history

-    @torch.no_grad()
+    @torch.inference_mode()
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
                     max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
                     return_past_key_values=False, **kwargs):
@@ -1069,7 +1069,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         else:
             yield response, new_history

-    @torch.no_grad()
+    @torch.inference_mode()
     def stream_generate(
         self,
         input_ids,
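For readers skimming the diff: torch.inference_mode() (available since PyTorch 1.9) disables gradient tracking like torch.no_grad(), but additionally skips autograd bookkeeping such as version counting and view tracking, so inference-only paths run with less overhead. A minimal sketch of the trade-off, using a hypothetical generate_step helper and a toy linear model rather than ChatGLM itself:

import torch

@torch.inference_mode()
def generate_step(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded here; the output is an "inference tensor".
    return model(x)

model = torch.nn.Linear(4, 2)   # toy stand-in for the real model
x = torch.randn(1, 4)

out = generate_step(model, x)
print(out.requires_grad)        # False

# Unlike no_grad(), tensors created under inference_mode() can never be used
# in autograd-tracked computation later; PyTorch raises a RuntimeError instead.
# Clone outside inference mode first if a result must feed into training:
trainable_copy = out.clone()

For generation-only entry points like chat, stream_chat, and stream_generate, this means slightly faster decoding, with the caveat that their outputs cannot flow back into a training graph.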