Upload hunyuan.py with huggingface_hub
hunyuan.py  CHANGED  (+1 -29)
@@ -41,7 +41,6 @@ from transformers.utils.import_utils import is_torch_fx_available
 from transformers.generation.utils import GenerateOutput
 from .configuration_hunyuan import HunYuanConfig
 from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
-from .vit_model import NaVitForward, VitForward, Vit
 
 
 if is_flash_attn_2_available():
@@ -363,16 +362,7 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
 
     def __init__(self, config: HunYuanConfig):
         super().__init__(config)
-
-        if "-tp" in config.vit_type:
-            config.vit_type = config.vit_type.replace("-tp", "")
-        self.vit_type = config.vit_type
-        if self.vit_type not in ['NaVit', 'EvaVit']:
-            if config.vit_mapping_type == 'mlp':
-                self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
-            self.vit = Vit(config)
-        else:
-            self.vit = None
+
         self.config = config
         self.model = HunYuanModel(config)
         self.add_classification_head = config.add_classification_head
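The block removed here is the constructor's vision wiring: it stripped a "-tp" suffix off `config.vit_type`, optionally built a `vit_linear_encoder` projection for the 'mlp' mapping type, and either instantiated the `Vit` tower or left `self.vit = None`. Below is a minimal, self-contained sketch of that pattern; `ToyVisionTower` and `ToyMultimodalLM` are hypothetical stand-ins for the real `Vit` and `HunYuanConfig`-driven class, not the actual HunYuan code:

```python
from typing import Optional

import torch


class ToyVisionTower(torch.nn.Module):
    """Hypothetical stand-in for the removed `Vit` module."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, image_feats: torch.Tensor) -> torch.Tensor:
        return self.proj(image_feats)


class ToyMultimodalLM(torch.nn.Module):
    """Sketch of the removed __init__ wiring, not the actual HunYuan class."""

    def __init__(self, hidden_size: int, vit_type: Optional[str],
                 vit_mapping_type: str = "mlp"):
        super().__init__()
        # The removed code normalized a tensor-parallel "-tp" suffix first.
        if vit_type is not None and "-tp" in vit_type:
            vit_type = vit_type.replace("-tp", "")
        self.vit_type = vit_type
        self.vit_linear_encoder = None
        if vit_type is not None:
            # Non-NaVit/EvaVit towers got an extra linear projection when
            # the mapping type was 'mlp'.
            if vit_type not in ("NaVit", "EvaVit") and vit_mapping_type == "mlp":
                self.vit_linear_encoder = torch.nn.Linear(hidden_size, hidden_size)
            self.vit = ToyVisionTower(hidden_size)
        else:
            self.vit = None  # text-only: downstream vision branches are skipped


model = ToyMultimodalLM(hidden_size=8, vit_type="EvaVit-tp")
print(model.vit_type, model.vit is not None)  # -> EvaVit True
```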
@@ -643,15 +633,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
         video_start_id = self.config.video_start_id
         video_end_id = self.config.video_end_id
 
-        if self.vit is not None and imgs is not None:
-            encoder_input = self.model.embed_tokens(input_ids)
-            if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
-                inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
-                    im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
-            else:
-                inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
-                    self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
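Both deleted forward-path branches follow the same recipe: embed the token ids with `embed_tokens`, run the vision tower over `imgs`, and splice the resulting features into the embedding sequence at the image-placeholder positions delimited by `im_start_id`/`im_end_id`. The sketch below illustrates only that splice step; `splice_image_features` is a hypothetical helper, not the actual `NaVitForward`/`VitForward` from `vit_model.py`:

```python
import torch


def splice_image_features(
    input_ids: torch.Tensor,     # (batch, seq) token ids
    token_embeds: torch.Tensor,  # (batch, seq, hidden) text embeddings
    image_feats: torch.Tensor,   # (num_placeholders, hidden) vision features
    image_token_id: int,
) -> torch.Tensor:
    """Replace placeholder-token embeddings with vision features, in order."""
    out = token_embeds.clone()
    mask = input_ids == image_token_id  # where the image placeholders sit
    assert int(mask.sum()) == image_feats.shape[0], "one feature per placeholder"
    out[mask] = image_feats.to(out.dtype)
    return out


# Toy check: a 1x6 sequence with two image placeholders (id 99).
ids = torch.tensor([[5, 99, 7, 99, 2, 3]])
emb = torch.zeros(1, 6, 4)
feats = torch.ones(2, 4)
spliced = splice_image_features(ids, emb, feats, image_token_id=99)
print(spliced[0, 1], spliced[0, 3])  # both rows are now the image features
```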
@@ -738,15 +719,6 @@ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
         if "inputs_embeds" in kwargs:
             raise NotImplementedError("`inputs_embeds` is not supported")
 
-        if self.vit is not None:
-            encoder_input = self.model.embed_tokens(inputs)
-            if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
-                inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
-                    self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
-            else:
-                inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
-                    self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
-
         return super().generate(
             inputs=input_ids,
             position_ids=position_ids,
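With this branch gone, the `generate` override no longer preprocesses `imgs`/`imgs_pos`; it validates kwargs and defers to `GenerationMixin.generate`, so the checkpoint behaves as a text-only causal LM. A usage sketch under that assumption; the repo id is a placeholder, not something stated in this commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "org/hunyuan-moe-checkpoint"  # placeholder repo id, not from the diff

tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype=torch.bfloat16
)

# No image preprocessing happens any more: plain token ids are passed
# positionally and flow straight through to GenerationMixin.generate.
input_ids = tok("Hello", return_tensors="pt").input_ids
out = model.generate(input_ids, max_new_tokens=16)
print(tok.decode(out[0], skip_special_tokens=True))
```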