duzx16 committed

Commit 7fabe56 · 1 parent: efb7a1e

Fix use_cache=False

modeling_chatglm.py (CHANGED, +5 -2)
```diff
@@ -897,6 +897,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             position_ids: Optional[torch.Tensor] = None,
+            use_cache: Optional[bool] = None,
             is_first_forward: bool = True,
             **kwargs
     ) -> dict:
```
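The first hunk gives `prepare_inputs_for_generation` an explicit `use_cache` parameter. In the `transformers` generation loop, the dict this method returns is unpacked directly into `forward`, so a setting reaches `forward` only if the dict carries it; before this commit, `use_cache` was swallowed by `**kwargs` and dropped. A minimal runnable sketch of that contract (`TinyModel` and its methods are illustrative stand-ins, not the real `GenerationMixin` or ChatGLM code):

```python
# Illustrative stand-in for the prepare_inputs_for_generation -> forward
# contract: the returned dict is splatted straight into forward(), so
# use_cache only arrives if it is an explicit key of that dict.

class TinyModel:
    def prepare_inputs_for_generation(self, input_ids, use_cache=None, **kwargs):
        # Post-fix behavior: use_cache is accepted and passed through.
        return {"input_ids": input_ids, "use_cache": use_cache}

    def forward(self, input_ids, use_cache=None):
        # Explicit value wins over the (stand-in) config default.
        use_cache = use_cache if use_cache is not None else True
        return f"forward(len={len(input_ids)}, use_cache={use_cache})"

model = TinyModel()
inputs = model.prepare_inputs_for_generation([1, 2, 3], use_cache=False)
print(model.forward(**inputs))  # forward(len=3, use_cache=False)
```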
```diff
@@ -904,7 +905,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
         if not is_first_forward:
-            if
+            if past_key_values is not None:
                 position_ids = position_ids[..., -1:]
                 input_ids = input_ids[:, -1:]
         return {
```
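The second hunk ties input truncation to the cache actually existing (the removed condition is cut off in the captured diff, but the replacement is clear). With `use_cache=False`, no `past_key_values` ever comes back from `forward`, so each decoding step must re-feed the whole sequence; truncating `input_ids` to the newest token is only valid when a KV cache supplies the prefix. A toy sketch of the guarded logic (illustrative only, not ChatGLM code):

```python
import torch

def prepare(input_ids: torch.Tensor, past_key_values, is_first_forward: bool):
    # Same shape as the fixed logic above: truncate only when a cache exists.
    if not is_first_forward and past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return input_ids

seq = torch.arange(6).unsqueeze(0)  # pretend token ids, shape (1, 6)
print(prepare(seq, past_key_values=None, is_first_forward=False).shape)
# torch.Size([1, 6]) - use_cache=False: the full sequence is re-fed
print(prepare(seq, past_key_values=("kv",), is_first_forward=False).shape)
# torch.Size([1, 1]) - cache present: only the newest token is fed
```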
```diff
@@ -912,7 +913,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             "past_key_values": past_key_values,
             "position_ids": position_ids,
             "attention_mask": attention_mask,
-            "return_last_logit": True
+            "return_last_logit": True,
+            "use_cache": use_cache
         }
 
     def forward(
```
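The third hunk returns `use_cache` alongside the existing keys; the trailing comma added to `"return_last_logit": True` only makes room for the new entry. Downstream, the usual Hugging Face idiom resolves such a value with an `is not None` check, which matters here precisely because an explicit `False` must survive; a truthiness fallback would clobber it. A small sketch, with `resolve_use_cache` as a hypothetical helper:

```python
def resolve_use_cache(use_cache, config_default=True):
    # Explicit value beats the config default, and an explicit False is
    # respected; `use_cache or config_default` would silently turn it True.
    return use_cache if use_cache is not None else config_default

print(resolve_use_cache(None))   # True  -> falls back to the config
print(resolve_use_cache(False))  # False -> explicit False survives
```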
```diff
@@ -1089,6 +1091,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             generation_config = self.generation_config
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)
+        model_kwargs["use_cache"] = generation_config.use_cache
         bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
 
         if isinstance(eos_token_id, int):
```
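The last hunk, further down the same class in the model's custom generation loop, explains the commit title. `GenerationConfig.update(**kwargs)` in `transformers` absorbs every kwarg that is a known config field and returns only the unrecognized ones, so a `use_cache=...` argument lands on `generation_config` and never in `model_kwargs`; the added line copies it back explicitly. The snippet below demonstrates that behavior:

```python
from transformers import GenerationConfig

config = GenerationConfig()
model_kwargs = config.update(use_cache=False, is_first_forward=True)

print(config.use_cache)  # False -> absorbed into the config
print(model_kwargs)      # {'is_first_forward': True} -> use_cache is gone
```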
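End to end, the fix means a `use_cache=False` passed to `generate` is honored on every decoding step instead of being dropped. A usage sketch, assuming the THUDM/chatglm3-6b checkpoint this file appears to belong to (any checkpoint shipping this `modeling_chatglm.py` behaves the same):

```python
from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id; adjust to whichever ChatGLM repo ships this file.
repo = "THUDM/chatglm3-6b"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()

inputs = tokenizer("Hello", return_tensors="pt")
# With the fix, use_cache=False reaches forward() on every step instead of
# being swallowed by GenerationConfig.update / **kwargs.
outputs = model.generate(**inputs, max_new_tokens=8, use_cache=False)
print(tokenizer.decode(outputs[0]))
```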
