Fix: AttributeError when `input_ids` is None during multimodal LLM training
When training a multimodal language model such as MiniGPT-4, the model is driven by `inputs_embeds` rather than `input_ids`: the multimodal embeddings are aligned to the LLM's text embedding space and concatenated with the text embeddings, so `input_ids` is never built and is passed as `None`.
This leads to the following error:
```
AttributeError: 'NoneType' object has no attribute 'shape'
```
This commit handles the `input_ids is None` case: `get_masks` now derives `batch_size`, `seq_length`, and the device from `padding_mask` when no token ids are given, and `ChatGLMModel.forward` falls back to `inputs_embeds.shape[:2]`, so the model can process the provided `inputs_embeds` without relying on `input_ids`.
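For context, a minimal sketch of the call path that hits this code during multimodal training. The checkpoint path, tensor sizes, and the random tensors standing in for the projected vision features and the embedded text prompt are placeholders, not part of this repository; the relevant part is the `input_ids=None` / `inputs_embeds` / `attention_mask` combination.

```python
import torch
from transformers import AutoModel

# Placeholder checkpoint path: any ChatGLM checkpoint that ships this modeling_chatglm.py.
model = AutoModel.from_pretrained("path/to/chatglm", trust_remote_code=True)

batch_size, img_tokens, txt_tokens = 2, 32, 16
hidden = model.config.hidden_size

# Stand-ins for the real multimodal pipeline: projected vision features
# (vision encoder + projection into the LLM's text space) and the embedded
# text prompt. Shapes are [batch, seq, hidden].
image_embeds = torch.randn(batch_size, img_tokens, hidden, dtype=model.dtype)
text_embeds = torch.randn(batch_size, txt_tokens, hidden, dtype=model.dtype)

# The visual prefix has no token ids, so the concatenated sequence is passed
# purely as embeddings and input_ids stays None.
inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)
attention_mask = torch.ones(batch_size, img_tokens + txt_tokens, dtype=torch.long)

# Before this fix, `batch_size, seq_length = input_ids.shape` raised
# AttributeError: 'NoneType' object has no attribute 'shape'.
outputs = model(input_ids=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask)
```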
- `modeling_chatglm.py` (+5, -4)
```diff
@@ -771,15 +771,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             if padding_mask is not None and not padding_mask.all():
                 return padding_mask
             return None
-        batch_size, seq_length = input_ids.shape
-        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+        batch_size, seq_length = input_ids.shape if input_ids is not None else padding_mask.shape
+        device = input_ids.device if input_ids is not None else padding_mask.device
+        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=device)
         full_attention_mask.tril_()
         past_length = 0
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
         if past_length:
             full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
-                                                        device=input_ids.device), full_attention_mask), dim=-1)
+                                                        device=device), full_attention_mask), dim=-1)
         if padding_mask is not None:
             full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
         if not past_length and padding_mask is not None:
@@ -872,7 +873,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        batch_size, seq_length = input_ids.shape
+        batch_size, seq_length = (input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] if inputs_embeds is not None else (None, None))
 
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)
```
