PenPaperKeyCode committed on
Commit
3169f6c
0 Parent(s):
Files changed (45)
  1. .gitattributes +8 -0
  2. LICENSE +122 -0
  3. README.md +459 -0
  4. chat_template.jinja +173 -0
  5. config.json +320 -0
  6. configuration_hyperclovax.py +228 -0
  7. configuration_vlm.py +169 -0
  8. cosyvoice.py +516 -0
  9. decoder/audio/NCCosybigvganDecoder.mar +3 -0
  10. decoder/audio/NCZSCosybigvganDecoder.mar +3 -0
  11. decoder/vision/model_index.json +25 -0
  12. decoder/vision/scheduler/scheduler_config.json +18 -0
  13. decoder/vision/token_embedder/config.json +7 -0
  14. decoder/vision/token_embedder/diffusion_pytorch_model.safetensors +3 -0
  15. decoder/vision/transformer/config.json +21 -0
  16. decoder/vision/transformer/diffusion_pytorch_model.safetensors +3 -0
  17. decoder/vision/transformer2/config.json +21 -0
  18. decoder/vision/transformer2/diffusion_pytorch_model.safetensors +3 -0
  19. decoder/vision/vae/config.json +38 -0
  20. decoder/vision/vae/diffusion_pytorch_model.safetensors +3 -0
  21. generation_config.json +6 -0
  22. mambamia_videoaudio_compressor.py +803 -0
  23. model-00001-of-00010.safetensors +3 -0
  24. model-00002-of-00010.safetensors +3 -0
  25. model-00003-of-00010.safetensors +3 -0
  26. model-00004-of-00010.safetensors +3 -0
  27. model-00005-of-00010.safetensors +3 -0
  28. model-00006-of-00010.safetensors +3 -0
  29. model-00007-of-00010.safetensors +3 -0
  30. model-00008-of-00010.safetensors +3 -0
  31. model-00009-of-00010.safetensors +3 -0
  32. model-00010-of-00010.safetensors +3 -0
  33. model.safetensors.index.json +0 -0
  34. modeling_hyperclovax.py +1866 -0
  35. modeling_vlm.py +0 -0
  36. patch_vuvlm.py +1085 -0
  37. preprocessor.py +0 -0
  38. preprocessor_config.json +32 -0
  39. processing_vlm.py +963 -0
  40. processor_config.json +6 -0
  41. special_tokens_map.json +30 -0
  42. ta_tok.py +379 -0
  43. tokenizer.json +3 -0
  44. tokenizer_config.json +3 -0
  45. video_preprocessor_config.json +89 -0
.gitattributes ADDED
@@ -0,0 +1,8 @@
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
3
+ *.pt filter=lfs diff=lfs merge=lfs -text
4
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
5
+ decoder/audio/NCCosybigvganDecoder.mar filter=lfs diff=lfs merge=lfs -text
6
+ decoder/audio/NCZSCosybigvganDecoder.mar filter=lfs diff=lfs merge=lfs -text
7
+ tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
8
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,122 @@
1
+ HyperCLOVA X SEED 8B Omni Model License Agreement
2
+
3
+ Model Release Date: December 29, 2025
4
+
5
+ This HyperCLOVA X SEED 8B Omni Model License Agreement (the “Agreement”) is a legal agreement between you and NAVER Corporation (“Naver Corp.”) and NAVER Cloud Corporation (“Naver Cloud Corp.”) (Naver Corp. and Naver Cloud Corp. are collectively referred to as “NAVER”) and governs your use of the Models that NAVER provides to You under this Agreement.
6
+
7
+ NAVER Corp., as the holder of the intellectual property of the Model, and its affiliate, NAVER Cloud Corp., as the exclusive business operator of HyperCLOVA X, enter into this Agreement with you. NAVER and you are each a “party” and collectively the “parties.”
8
+
9
+ By using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model or Derivative Model, or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement. You represent to us that you are lawfully able to enter into contracts, and if you are entering into this Agreement for an entity, that you have legal authority to bind that entity.
10
+
11
+ 1. Definitions.
12
+
13
+ 1.1. "Affiliate” means any entity directly or indirectly controlling, controlled by or under common control with either party, where “control” means the possession, directly or indirectly, of the power to independently direct or cause the direction of the management and policies of an entity, whether through ownership of more than fifty percent (50%) of the stock or other equity interests entitled to vote for representation on its board of directors, or body performing similar functions, by contract or otherwise.
14
+
15
+ 1.2. “Derivative Model” means all (i) modifications to the Model, (ii) works based on the Model, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of the Model, to that model in order to cause that model to perform similarly to the Model, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by the Model for training that Model. For clarity, Outputs are not deemed Derivative Model.
16
+
17
+ 1.3. “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
18
+
19
+ 1.4. “Model” means the foundational large language models and software and algorithms, including machine-learning model code and trained model weights distributed by NAVER.
20
+
21
+
22
+ 1.5. “Output” means the information content output of the Model or a Derivative Model that results from operating or otherwise using the Model or Derivative Model.
23
+
24
+ 2. Conditions for Use, License Grant and Restrictions
25
+
26
+ 2.1. Conditions for Use. The Model and any Derivative Model are subject to the terms of this Agreement, which govern your use. If You institute copyright or patent litigation against any entity (including a crossclaim or counterclaim in a lawsuit) alleging that the Model or Derivative Model constitutes direct or contributory copyright or patent infringement, then any license granted to you under this Agreement for that Model or Derivative Model will terminate as of the date such litigation is filed. NAVER may update this Agreement to comply with legal and regulatory requirements at any time and You agree to either comply with any updated license or cease your copying, use, and distribution of the Model and any Derivative Model.
27
+
28
+ 2.2. License Grant. Subject to the terms and conditions of this Agreement, NAVER hereby grants to you a non-exclusive, worldwide, non-transferable, revocable and royalty-free limited license under NAVER’s intellectual property or other rights owned by NAVER embodied in the Model to access, download, install, copy, use, reproduce, distribute, create derivative works of, and make modifications to the Model.
29
+
30
+ 2.3. Prohibited Use Policy. NAVER is committed to ensuring safety, trust, and transparency in the development and use of AI technologies. Accordingly, your use of the Model and any Derivative Models is subject to the following conditions:
31
+ (i) You must ensure that any product or service you develop, use, offer as a service, or distribute complies with all applicable laws and regulations, and is operated appropriately for the relevant industry or use case.
32
+ (ii) You must comply with the Acceptable Use Policy applicable to the Model and any Derivative Models, which is attached hereto as Addendum A and incorporated by reference into this Agreement.
33
+ (iii) NAVER expressly prohibits the use of its products or services for any purpose in violation of applicable law and regulation, including but not limited to:
34
+ (a) illegal surveillance,
35
+ (b) illegal collection or processing of biometric information without the consent of the subject which is required under applicable law, or
36
+ (c) illegal harassment, abuse, threatening or bullying of individuals or groups of individuals or intentionally misleading or deceiving others.
37
+ (iv) You must take reasonable measures to address unintended bias and to mitigate harm to others, including underrepresented or vulnerable groups.
38
+
39
+
40
+ 3. Redistribution.
41
+
42
+ 3.1. You may reproduce, distribute or make available the Model or Derivative Models thereof, or a product or service (including another AI model) that contains any of them, if you meet all of the following conditions: you must (i) include the Prohibited Use Policy referenced in Section 2.3. as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of the Model or Derivative Model and you must provide notice to subsequent users to whom you distribute that the Model or Derivative Models are subject to the use restrictions in Section 2.3., (ii) provide all third party recipients of the Model or Derivative Models a copy of this Agreement, (iii) cause any modified files to carry prominent notices stating that you modified the files; (iv) include the following attribution notice within a “Notice” text file distributed as part of such copies: “HyperCLOVA X SEED 8B Omni Model is licensed under the HyperCLOVA X SEED 8B Omni Model License Agreement, Copyright © NAVER Corp. All Rights Reserved.”, and (v) prominently display “Powered by HyperCLOVA X” on a related website, user interface, blogpost, about page, or product documentation. If you use the Model or any Outputs of the Model to create, train, fine-tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “HyperCLOVA X” at the beginning of any such AI model name.
43
+ 3.2. You may add your own copyright statement to your modifications and, except as set forth in this Section, may provide additional or different license terms and conditions for use, reproduction, or distribution of your modifications, or for any such Derivative Models as a whole, provided your use, reproduction, and distribution of the Model or Derivative Models otherwise comply with the terms and conditions stated in this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement.
44
+
45
+ 4. Additional Commercial Terms. If (i) as of the Model Release Date, the monthly active users of the products or services made available by or for Licensee, or Licensee’s Affiliates, is greater than 10 million monthly active users in the preceding calendar month, or (ii) the Licensee or its Affiliate distributes or makes available any product or service, which is substantially similar to or directly competes with any product and service provided by NAVER, then the Licensee must request a license from NAVER. Such a license may be granted by NAVER at its sole discretion, and the Licensee is not authorized to exercise any rights under this Agreement unless and until NAVER expressly grants you such rights.
46
+
47
+ 5. Generated Output. NAVER claims no rights in Outputs you generate using the Model. You are solely responsible for Outputs and their subsequent uses.
48
+
49
+ 6. DISCLAIMER OF WARRANTY. UNLESS REQUIRED BY APPLICABLE LAW, THE MODEL AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND NAVER DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE MODEL, DERIVATIVE MODELS, OUTPUTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE MODEL AND ANY OUTPUTS AND RESULTS AND YOUR EXERCISE OF PERMISSION UNDER THIS AGREEMENT.
50
+
51
+ 7. LIMITATION OF LIABILITY. IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW (SUCH AS IN CASES OF DELIBERATE AND GROSSLY NEGLIGENT ACTS), WILL NAVER BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND, ARISING FROM OR RELATED TO THIS AGREEMENT, OR RESULTING FROM THE USE OR INABILITY TO USE THE MODEL, DERIVATIVE MODELS OR, OUTPUTS (INCLUDING, BUT NOT LIMITED TO, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGES, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF NAVER HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
52
+
53
+ 8. Indemnity. You will indemnify and hold harmless NAVER from and against any claim by any third party arising out of or related to your use or distribution of the Model, Derivative Model or Outputs.
54
+
55
+ 9. Intellectual Property.
56
+
57
+ 9.1. This Agreement does not grant permission to use the trade names, trademarks, service marks, or product names of NAVER, except as required for reasonable and customary use in describing the origin of the Model and reproducing the content of the “Notice” text file.
58
+
59
+ 9.2. NAVER Corp. owns the Model and any Derivative Model created by NAVER Corp. Except as expressly granted in this Agreement, NAVER Corp. reserves all rights, interests and remedies in connection with the Model and Derivative Model created by NAVER Corp. and no other license or right is granted to you by implication, estoppel or otherwise. Subject to NAVER Corp.’s ownership of the Model and any Derivative Model made by or for NAVER Corp., with respect to any derivative works and modifications of the Model that are made by you, as between you and NAVER Corp., you are and will be the owner of such derivative works and modifications.
60
+
61
+ 10. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Model and will continue in full force and effect until terminated in accordance with the terms and conditions of this Agreement. NAVER may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Model and Derivative Model. Section 5, 6, 7 and 10 shall survive the termination of this Agreement.
62
+
63
+ 11. Governing Law and Jurisdiction.
64
+
65
+ 11.1. This Agreement will be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflicts of laws principles.
66
+
67
+ 11.2. Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English. Either party may seek interim or provisional relief from a court of competent jurisdiction and doing so shall not be considered a waiver of any provision in this section. The arbitral tribunal also has the authority to issue orders for interim or provisional relief.
68
+
69
+ 12. Modifications. NAVER reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on our website or through other means of communication. You are responsible for reviewing the Agreement periodically for changes.
70
+
71
+ 13. No Waiver. NAVER will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement.
72
+
73
+
74
+
75
+ Addendum A – Acceptable Use Policy
76
+
77
+ NAVER is committed to promoting safe and responsible use of its AI technologies, including the HyperCLOVA X SEED 8B Omni Model (the “Model”). By accessing or using the Model and Derivative Model (as defined in the Model License Agreement) (the Model and Derivative Model are collectively referred to as the “Models”), you agree to this Acceptable Use Policy (“Policy”).
78
+
79
+ We want everyone to use the Models safely, legally, and ethically. You agree that you will not use, or allow others to use, the Models to:
80
+
81
+ 1. Violate applicable laws or the rights of others, including by:
82
+ a. Engaging in, promoting, contributing to, encouraging, planning, inciting, or furthering illegal or unlawful activity or content, such as:
83
+ - Violence or terrorism
84
+ - Exploitation or harm to children, including the creation or dissemination of child exploitative content
85
+ - Human trafficking, exploitation, or sexual violence
86
+ - The unlawful distribution of obscene or harmful material to minors, or failure to apply legally required age restrictions
87
+ - Sexual solicitation or sexually exploitative behavior
88
+ - Any other criminal activity
89
+ b. Engaging in, promoting, inciting, or facilitating the harassment, abuse, threatening, or bullying of individuals or groups
90
+ c. Engaging in, promoting, inciting, or facilitating discrimination or other unlawful or harmful conduct in the provision of employment, credit, housing, or access to essential goods and services
91
+ d. Providing unauthorized or unlicensed professional services, including but not limited to financial, legal, medical/health, or related services
92
+ e. Collecting, processing, disclosing, generating, or inferring private or sensitive personal information, including identity, health, or demographic data, unless lawfully permitted under applicable laws
93
+ f. Infringing, misappropriating, or otherwise violating third-party rights, including through the generation or use of outputs derived from the Models
94
+ g. Creating, generating, or facilitating malicious code, malware, or computer viruses, or interfering with the functioning, security, or integrity of a website, application, or system
95
+ h. Intentionally bypassing or disabling usage restrictions, safety measures, or access controls imposed by NAVER
96
+
97
+ 2. Engage in or promote use cases that may pose a risk of death, bodily harm, or significant safety hazard to individuals, including use of the Models in connection with:
98
+ a. Military, warfare, nuclear technology or espionage
99
+ b. The development or distribution of firearms or illegal weapons
100
+ c. Illegal drugs or regulated controlled substances
101
+ d. Operation of critical infrastructure, transportation systems, or heavy machinery
102
+ e. Content promoting self-harm, including suicide, or eating disorders
103
+ f. Any other use intended to incite or cause physical harm
104
+
105
+ 3. Intentionally deceive or mislead others, including by:
106
+ a. Generating, promoting, or disseminating fraudulent or misleading content
107
+ b. Creating or sharing defamatory content
108
+ c. Generating or distributing spam
109
+ d. Impersonating another individual or entity without proper authorization
110
+ e. Representing Model output as human-generated
111
+ f. Generating or enabling fake online engagement, such as fake reviews or fake users
112
+
113
+ 4. Fail to disclose to end users any known risks or limitations of an AI system that incorporates the Models.
114
+
115
+ 5. Use the Models in conjunction with third-party tools, models, or software designed to generate unlawful content or conduct, or falsely represent outputs from such tools as associated with NAVER or HyperCLOVA X.
116
+
117
+ If you become aware of a violation of this Policy, a bug, or any behavior that could result in a breach of this Policy, please report it to us:
118
+
119
+ Reporting risky outputs: [email protected]
120
+ Reporting policy violations or unauthorized use: [email protected]
121
+
122
+
README.md ADDED
@@ -0,0 +1,459 @@
1
+ ---
2
+ license: other
3
+ license_name: hyperclovax
4
+ license_link: LICENSE
5
+ library_name: transformers
6
+ ---
7
+
8
+ ![image](https://cdn-uploads.huggingface.co/production/uploads/64383d54c5a91b84ece18d62/3gaPG3_F4Fxn-SOZWrmfU.png)
9
+
10
+ # Overview
11
+ HyperCLOVA X SEED 8B Omni is a unified multimodal model that brings text, vision, and speech together, based on an auto-regressive Transformer architecture, enabling consistent multimodal understanding and generation. SEED 8B Omni aligns textual, visual, and audio representations in a shared semantic space and supports bidirectional interactions across modalities, including established text capabilities as well as vision–language QA, text-to-image generation and editing, speech recognition and translation, and text-to-speech, within a 32K context window. As an early pathfinding milestone of HyperCLOVA X toward **Any-to-Any-Korean-First** intelligence, SEED 8B Omni serves as a practical exploration of unified multimodal modeling and provides a reference point for future development and scaling.
12
+
13
+ ---
14
+
15
+ # Basic Information
16
+
17
+ - **Architecture**: Transformer-based omni-model architecture (Dense Model)
18
+ - **Parameters**: 8B
19
+ - **Input Format**: Text/Image/Video/Audio(Speech)
20
+ - **Output Format**: Text/Image/Audio(Speech)
21
+ - **Context Length**: 32K
22
+ - **Knowledge Cutoff**: May 2025
23
+
24
+ ---
25
+
26
+ # Benchmarks
27
+ ![Technical Report 05_2@2x](https://cdn-uploads.huggingface.co/production/uploads/646acf46086023e36edce4c4/x1IvD9Rt_NK71CklecpN2.png)
28
+
29
+
30
+ - **Text-to-Text**: MMLU-Pro, GSM8K, KMMLU-Pro, HAERAE 1.0
31
+ - **Vision-to-Text**: SEED-IMG, AI2D, K-MMBench
32
+ - **Text-to-Vision**: GenEval, ImgEdit
33
+ - **Audio-to-Text**: Librispeech, Ksponspeech
34
+ - **Audio-to-Audio**: Fleurs en2ko, Fleurs ko2en
35
+
36
+ ---
37
+
38
+ # Examples
39
+ ## Text-to-Image Generation
40
+ ![hf_img01](https://cdn-uploads.huggingface.co/production/uploads/64383d54c5a91b84ece18d62/6fRekMbt_9ab5I80GTkdG.png)
41
+ ## Text-based Image Editing
42
+ ![hf_img02](https://cdn-uploads.huggingface.co/production/uploads/64383d54c5a91b84ece18d62/aoecU357A0fVvR8uerozh.png)
43
+ ![hf_img03](https://cdn-uploads.huggingface.co/production/uploads/64383d54c5a91b84ece18d62/0fpcq--rj1kqPa9m8DYgt.png)
44
+ ![hf_img04](https://cdn-uploads.huggingface.co/production/uploads/64383d54c5a91b84ece18d62/Z24JUQZSmeaVNrhDMYG6K.png)
45
+
46
+ ---
47
+
48
+ # Inference
49
+
50
+ We provide [OmniServe](https://github.com/NAVER-Cloud-HyperCLOVA-X/OmniServe), a production-ready multimodal inference system with an OpenAI-compatible API.
51
+
52
+ ## Capabilities
53
+
54
+ - **Inputs**: Text, Image, Audio, Video
55
+ - **Outputs**: Text, Image, Audio (no video generation)
56
+
57
+ ## Requirements
58
+
59
+ - 4x NVIDIA A100 80GB
60
+ - Docker & Docker Compose
61
+ - NVIDIA Driver 525+, CUDA 12.1+
62
+ - S3-compatible storage (for image/audio output)
63
+
64
+ ## Installation
65
+
66
+ ```bash
67
+ # Clone OmniServe
68
+ git clone https://github.com/NAVER-Cloud-HyperCLOVA-X/OmniServe.git
69
+ cd OmniServe
70
+
71
+ # Install dependencies
72
+ pip install huggingface_hub safetensors torch openai easydict
73
+
74
+ # Download model (~16GB)
75
+ huggingface-cli download naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B \
76
+ --local-dir ./models/HyperCLOVAX-SEED-Omni-8B
77
+
78
+ # Convert model to component format
79
+ python convert_model.py \
80
+ --input ./models/HyperCLOVAX-SEED-Omni-8B \
81
+ --output ./track_b \
82
+ --track b
83
+
84
+ # Configure environment
85
+ cp .env.example .env
86
+ # Edit .env with model paths and S3 credentials
87
+
88
+ # Build and run (Track B only - OMNI model)
89
+ docker compose --profile track-b build
90
+ docker compose --profile track-b up -d
91
+
92
+ # Wait for model loading (~5 minutes)
93
+ docker compose logs -f omni
94
+
95
+ # Note: To run both VLM and OMNI models together:
96
+ # docker compose --profile track-a --profile track-b up -d
97
+ ```
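Once the containers are up, the model can take several minutes to finish loading (see the log-follow step above). The snippet below is an illustrative sketch, not part of OmniServe: it simply polls the OpenAI-compatible endpoint used throughout this README until it starts answering. The base URL, model name, and `skip_reasoning` flag are the same ones used in the examples that follow.

```python
import time

from openai import OpenAI

# Same endpoint and model name as in the usage examples below.
client = OpenAI(base_url="http://localhost:8000/b/v1", api_key="not-needed")


def wait_until_ready(retries: int = 20, delay_s: float = 15.0) -> bool:
    """Poll the chat endpoint until the model responds, or give up."""
    for attempt in range(1, retries + 1):
        try:
            client.chat.completions.create(
                model="track_b_model",
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=8,
                extra_body={"chat_template_kwargs": {"skip_reasoning": True}},
            )
            return True  # the server answered, so the model is loaded
        except Exception as exc:  # connection refused / 5xx while still loading
            print(f"[{attempt}/{retries}] not ready yet: {exc}")
            time.sleep(delay_s)
    return False


if __name__ == "__main__":
    print("ready" if wait_until_ready() else "gave up waiting")
```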
98
+
99
+ ## Basic Usage
100
+
101
+ ```python
102
+ from openai import OpenAI
103
+
104
+ client = OpenAI(
105
+ base_url="http://localhost:8000/b/v1",
106
+ api_key="not-needed"
107
+ )
108
+
109
+ # Image understanding
110
+ response = client.chat.completions.create(
111
+ model="track_b_model",
112
+ messages=[
113
+ {
114
+ "role": "user",
115
+ "content": [
116
+ {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
117
+ {"type": "text", "text": "What is in this image?"}
118
+ ]
119
+ }
120
+ ],
121
+ max_tokens=256,
122
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
123
+ )
124
+
125
+ print(response.choices[0].message.content)
126
+ ```
127
+
128
+ ## More Examples
129
+
130
+ <details>
131
+ <summary>Text to Image</summary>
132
+
133
+ ```python
134
+ import json
135
+
136
+ SYSTEM_PROMPT = """You are an AI assistant that generates images. When asked to draw or create an image, you MUST use the t2i_model_generation tool to generate the image. Always respond by calling the tool."""
137
+
138
+ response = client.chat.completions.create(
139
+ model="track_b_model",
140
+ messages=[
141
+ {"role": "system", "content": SYSTEM_PROMPT},
142
+ {"role": "user", "content": "Draw a sunset over mountains"}
143
+ ],
144
+ tools=[{
145
+ "type": "function",
146
+ "function": {
147
+ "name": "t2i_model_generation",
148
+ "description": "Generates an RGB image based on the provided discrete image representation.",
149
+ "parameters": {
150
+ "type": "object",
151
+ "required": ["discrete_image_token"],
152
+ "properties": {
153
+ "discrete_image_token": {
154
+ "type": "string",
155
+ "description": "A serialized string of discrete vision tokens, encapsulated by special tokens. The format must be strictly followed: <|discrete_image_start|><|vision_ratio_4:3|><|vision_token|><|visionaaaaa|><|visionbbbbb|>... <|visionzzzzz|><|vision_eol|><|vision_eof|><|discrete_image_end|>.",
156
+ "minLength": 1
157
+ }
158
+ }
159
+ }
160
+ }
161
+ }],
162
+ max_tokens=7000,
163
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
164
+ )
165
+
166
+ if response.choices[0].message.tool_calls:
167
+ args = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
168
+ print(f"Generated image: {args['discrete_image_token']}")
169
+ ```
170
+
171
+ </details>
172
+
173
+ <details>
174
+ <summary>Text to Audio</summary>
175
+
176
+ ```python
177
+ import base64
178
+
179
+ # Prompt should explicitly request speech/audio output
180
+ response = client.chat.completions.create(
181
+ model="track_b_model",
182
+ messages=[{
183
+ "role": "user",
184
+ "content": "Read this text aloud in a cheerful female voice:\nHello! How are you today?"
185
+ }],
186
+ max_tokens=1000,
187
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
188
+ )
189
+
190
+ if response.choices[0].message.audio:
191
+ audio_url = base64.b64decode(response.choices[0].message.audio.data).decode()
192
+ print(f"Generated audio: {audio_url}")
193
+ ```
194
+
195
+ </details>
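The decoded value above is a URL rather than raw audio bytes; generated files are uploaded to the configured S3 bucket (see the Architecture and S3 Configuration sections below). A minimal follow-up sketch for saving the result locally, assuming the decoded value is a plain, publicly reachable HTTPS URL:

```python
import urllib.request

# Placeholder: use the `audio_url` decoded in the Text-to-Audio example above.
audio_url = "https://your-bucket.example.com/generated/output.wav"

# Download the generated file and write it to disk.
with urllib.request.urlopen(audio_url) as resp, open("output.wav", "wb") as out:
    out.write(resp.read())
print("saved output.wav")
```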
196
+
197
+ <details>
198
+ <summary>Audio Input</summary>
199
+
200
+ ```python
201
+ import base64
202
+
203
+ audio_url = "https://example.com/audio.mp3"
204
+ audio_data = base64.b64encode(audio_url.encode()).decode()
205
+
206
+ response = client.chat.completions.create(
207
+ model="track_b_model",
208
+ messages=[
209
+ {
210
+ "role": "user",
211
+ "content": [
212
+ {"type": "input_audio", "input_audio": {"data": audio_data, "format": "mp3"}},
213
+ {"type": "text", "text": "What is being said?"}
214
+ ]
215
+ }
216
+ ],
217
+ max_tokens=256,
218
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
219
+ )
220
+
221
+ print(response.choices[0].message.content)
222
+ ```
223
+
224
+ </details>
225
+
226
+ <details>
227
+ <summary>Video Input</summary>
228
+
229
+ ```python
230
+ response = client.chat.completions.create(
231
+ model="track_b_model",
232
+ messages=[
233
+ {
234
+ "role": "user",
235
+ "content": [
236
+ {"type": "image_url", "image_url": {"url": "https://example.com/video.mp4"}},
237
+ {"type": "text", "text": "Describe this video."}
238
+ ]
239
+ }
240
+ ],
241
+ max_tokens=512,
242
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
243
+ )
244
+
245
+ print(response.choices[0].message.content)
246
+ ```
247
+
248
+ </details>
249
+
250
+ <details>
251
+ <summary>Image to Image</summary>
252
+
253
+ ```python
254
+ import json
255
+
256
+ SYSTEM_PROMPT = """You are an AI assistant that transforms images. When asked to transform, edit, or stylize an image, you MUST use the t2i_model_generation tool to generate the new image. Always respond by calling the tool."""
257
+
258
+ response = client.chat.completions.create(
259
+ model="track_b_model",
260
+ messages=[
261
+ {"role": "system", "content": SYSTEM_PROMPT},
262
+ {
263
+ "role": "user",
264
+ "content": [
265
+ {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
266
+ {"type": "text", "text": "Transform to watercolor style"}
267
+ ]
268
+ }
269
+ ],
270
+ tools=[{
271
+ "type": "function",
272
+ "function": {
273
+ "name": "t2i_model_generation",
274
+ "description": "Generates an RGB image based on the provided discrete image representation.",
275
+ "parameters": {
276
+ "type": "object",
277
+ "required": ["discrete_image_token"],
278
+ "properties": {
279
+ "discrete_image_token": {
280
+ "type": "string",
281
+ "description": "A serialized string of discrete vision tokens, encapsulated by special tokens. The format must be strictly followed: <|discrete_image_start|><|vision_ratio_4:3|><|vision_token|><|visionaaaaa|><|visionbbbbb|>... <|visionzzzzz|><|vision_eol|><|vision_eof|><|discrete_image_end|>.",
282
+ "minLength": 1
283
+ }
284
+ }
285
+ }
286
+ }
287
+ }],
288
+ max_tokens=7000,
289
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
290
+ )
291
+
292
+ if response.choices[0].message.tool_calls:
293
+ args = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
294
+ print(f"Generated image: {args['discrete_image_token']}")
295
+ ```
296
+
297
+ </details>
298
+
299
+ <details>
300
+ <summary>Audio to Audio</summary>
301
+
302
+ ```python
303
+ import base64
304
+
305
+ # Input audio (URL encoded as base64)
306
+ audio_url = "https://example.com/input.mp3"
307
+ audio_data = base64.b64encode(audio_url.encode()).decode()
308
+
309
+ response = client.chat.completions.create(
310
+ model="track_b_model",
311
+ messages=[
312
+ {
313
+ "role": "user",
314
+ "content": [
315
+ {"type": "input_audio", "input_audio": {"data": audio_data, "format": "mp3"}},
316
+ {"type": "text", "text": "Listen to this and respond with speech"}
317
+ ]
318
+ }
319
+ ],
320
+ max_tokens=2000,
321
+ extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
322
+ )
323
+
324
+ if response.choices[0].message.audio:
325
+ audio_url = base64.b64decode(response.choices[0].message.audio.data).decode()
326
+ print(f"Generated audio: {audio_url}")
327
+ ```
328
+
329
+ </details>
330
+
331
+ <details>
332
+ <summary>Using curl</summary>
333
+
334
+ ```bash
335
+ # Image understanding
336
+ curl -X POST http://localhost:8000/b/v1/chat/completions \
337
+ -H "Content-Type: application/json" \
338
+ -d '{
339
+ "model": "track_b_model",
340
+ "messages": [{"role": "user", "content": [
341
+ {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
342
+ {"type": "text", "text": "Describe this image."}
343
+ ]}],
344
+ "max_tokens": 256,
345
+ "extra_body": {"chat_template_kwargs": {"skip_reasoning": true}}
346
+ }'
347
+
348
+ # Text to audio
349
+ curl -X POST http://localhost:8000/b/v1/chat/completions \
350
+ -H "Content-Type: application/json" \
351
+ -d '{
352
+ "model": "track_b_model",
353
+ "messages": [{"role": "user", "content": "Say hello"}],
354
+ "max_tokens": 1000,
355
+ "extra_body": {"chat_template_kwargs": {"skip_reasoning": true}}
356
+ }'
357
+ ```
358
+
359
+ </details>
360
+
361
+
362
+ ## Architecture
363
+
364
+ ```
365
+ User Request
366
+ (Image/Audio/Video/Text)
367
+
368
+
369
+ ┌─────────────────────────────────────────────────────────────────────────┐
370
+ │ OmniServe │
371
+ │ POST /b/v1/chat/completions │
372
+ │ │
373
+ │ ┌──────────────────────────────────────────────────────────────────┐ │
374
+ │ │ [1] INPUT ENCODING │ │
375
+ │ │ │ │
376
+ │ │ ┌─────────────────┐ ┌─────────────────┐ │ │
377
+ │ │ │ Vision Encoder │ │ Audio Encoder │ │ │
378
+ │ │ └────────┬────────┘ └────────┬────────┘ │ │
379
+ │ │ │ │ │ │
380
+ │ │ └────────────┬────────────────────┘ │ │
381
+ │ │ │ embeddings │ │
382
+ │ └──────────────────────────┼───────────────────────────────────────┘ │
383
+ │ ▼ │
384
+ │ ┌──────────────┐ │
385
+ │ │ LLM (8B) │◀──── text │
386
+ │ └──────┬───────┘ │
387
+ │ │ │
388
+ │ ┌─────────────────────────┼────────────────────────────────────────┐ │
389
+ │ │ [2] OUTPUT DECODING │ │
390
+ │ │ │ │ │
391
+ │ │ ┌──────────────┼──────────────┐ │ │
392
+ │ │ ▼ ▼ ▼ │ │
393
+ │ │ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ │
394
+ │ │ │ Text │ │ Vision │ │ Audio │ │ │
395
+ │ │ │ │ │ Decoder │ │ Decoder │ │ │
396
+ │ │ └───────────┘ └─────┬─────┘ └─────┬─────┘ │ │
397
+ │ │ │ │ │ │
398
+ │ │ ▼ ▼ │ │
399
+ │ │ Image URL Audio URL │ │
400
+ │ │ (S3) (S3) │ │
401
+ │ └──────────────────────────────────────────────────────────────────┘ │
402
+ │ │
403
+ └─────────────────────────────────────────────────────────────────────────┘
404
+
405
+
406
+ Response
407
+ (Text / Image URL / Audio URL)
408
+ ```
409
+
410
+ ## Hardware Requirements
411
+
412
+ | Component | GPU | VRAM |
413
+ |-----------|-----|------|
414
+ | Vision Encoder | 1x | ~8GB |
415
+ | Audio Encoder | (shared) | ~4GB |
416
+ | LLM (8B) | 1x | ~16GB |
417
+ | Vision Decoder | 1x | ~16GB |
418
+ | Audio Decoder | (shared) | ~4GB |
419
+ | **Total** | **3x** | **~48GB** |
420
+
421
+ ## Key Parameters
422
+
423
+ | Parameter | Description | Default |
424
+ |-----------|-------------|---------|
425
+ | `chat_template_kwargs.skip_reasoning` | Skip the reasoning (`<think>`) block in the output | `true` |
426
+ | `max_tokens` | Max output tokens | - |
427
+ | `temperature` | Sampling temperature | 0.7 |
428
+ | `tools` | Required for image generation | - |
429
+
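As a consolidated illustration (not an additional API surface), the request below combines the parameters from the table, using the same endpoint and model name as the earlier examples:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/b/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="track_b_model",
    messages=[{"role": "user", "content": "Give a two-sentence summary of what an omni-modal model is."}],
    max_tokens=256,   # cap on generated tokens
    temperature=0.7,  # default sampling temperature from the table
    # skip_reasoning skips the <think> block; a `tools` list would be added here for image generation
    extra_body={"chat_template_kwargs": {"skip_reasoning": True}},
)
print(response.choices[0].message.content)
```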
430
+ ## S3 Configuration
431
+
432
+ Required for image/audio generation:
433
+
434
+ ```bash
435
+ NCP_S3_ENDPOINT=https://your-s3-endpoint.com
436
+ NCP_S3_ACCESS_KEY=your-access-key
437
+ NCP_S3_SECRET_KEY=your-secret-key
438
+ NCP_S3_BUCKET_NAME=your-bucket-name
439
+ ```
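To sanity-check these credentials before starting the stack, here is a small sketch using `boto3` against the S3-compatible endpoint (boto3 is an assumption here, not an OmniServe dependency; any S3-compatible client should work):

```python
import os

import boto3

# Read the same variables that OmniServe expects in .env
s3 = boto3.client(
    "s3",
    endpoint_url=os.environ["NCP_S3_ENDPOINT"],
    aws_access_key_id=os.environ["NCP_S3_ACCESS_KEY"],
    aws_secret_access_key=os.environ["NCP_S3_SECRET_KEY"],
)

bucket = os.environ["NCP_S3_BUCKET_NAME"]
# head_bucket raises an error if the bucket is missing or the credentials are wrong.
s3.head_bucket(Bucket=bucket)
print(f"bucket '{bucket}' is reachable")
```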
440
+
441
+ For more details, see [OmniServe documentation](https://github.com/NAVER-Cloud-HyperCLOVA-X/OmniServe).
442
+
443
+
444
+ ---
445
+
446
+ # Citation
447
+ TBU (Technical Report)
448
+
449
+ ---
450
+
451
+ # Questions
452
+ For any other questions, please feel free to contact us at [email protected].
453
+
454
+
455
+ ---
456
+
457
+
458
+ # License
459
+ The model is licensed under the [HyperCLOVA X SEED 8B Omni Model License Agreement](./LICENSE).
chat_template.jinja ADDED
@@ -0,0 +1,173 @@
1
+ {%- set ns_img = namespace(count=0) %}
2
+ {%- set ns_aud = namespace(count=0) %}
3
+ {%- set ns_vid = namespace(count=0) %}
4
+ {%- if tools %}
5
+ {{- '<|im_start|>system\n' }}
6
+ {%- if messages[0].role == 'system' %}
7
+ {%- if messages[0].content is string %}
8
+ {{- messages[0].content + '\n\n' }}
9
+ {%- elif messages[0].content is sequence %}
10
+ {%- for part in messages[0].content %}
11
+ {%- if part.type == 'text' %}
12
+ {{- part.text }}
13
+ {%- endif %}
14
+ {%- endfor %}
15
+ {{- '\n\n' }}
16
+ {%- endif %}
17
+ {%- endif %}
18
+ {{- '# Tools\n\n' }}
19
+ {{- 'You may call one or more functions to assist with the user query.\n\n' }}
20
+ {{- 'You are provided with function signatures within <tools></tools> XML tags:\n' }}
21
+ {{- '<tools>\n' }}
22
+ {%- for tool in tools %}
23
+ {{- tool | tojson }}
24
+ {%- endfor %}
25
+ {{- '\n</tools>\n\n' }}
26
+ {{- 'For each function call, output the function name and arguments within the following XML format:\n' }}
27
+ {{- '<tool_call>{function-name}\n' }}
28
+ {{- '<arg_key>{arg-key-1}</arg_key>\n' }}
29
+ {{- '<arg_value>{arg-value-1}</arg_value>\n' }}
30
+ {{- '<arg_key>{arg-key-2}</arg_key>\n' }}
31
+ {{- '<arg_value>{arg-value-2}</arg_value>\n' }}
32
+ {{- '...\n' }}
33
+ {{- '</tool_call><|im_end|>\n' }}
34
+ {%- else %}
35
+ {%- if messages[0].role == 'system' %}
36
+ {{- '<|im_start|>system\n' }}
37
+ {%- if messages[0].content is string %}
38
+ {{- messages[0].content }}
39
+ {%- elif messages[0].content is sequence %}
40
+ {%- for part in messages[0].content %}
41
+ {%- if part.type == 'text' %}
42
+ {{- part.text }}
43
+ {%- endif %}
44
+ {%- endfor %}
45
+ {%- endif %}
46
+ {{- '<|im_end|>\n' }}
47
+ {%- endif %}
48
+ {%- endif %}
49
+ {%- set ns = namespace(last_user_index=-1) %}
50
+ {%- for m in messages %}
51
+ {%- if m.role == 'user' %}
52
+ {%- set ns.last_user_index = loop.index0 %}
53
+ {%- endif %}
54
+ {%- endfor %}
55
+ {%- for message in messages %}
56
+ {%- set content = message.content %}
57
+ {%- if (message.role == 'system' and not loop.first) %}
58
+ {{- '<|im_start|>' + message.role + '\n' }}
59
+ {%- if content is string %}
60
+ {{- content }}
61
+ {%- elif content is sequence %}
62
+ {%- for part in content %}
63
+ {%- if part.type == 'text' %}
64
+ {{- part.text }}
65
+ {%- endif %}
66
+ {%- endfor %}
67
+ {%- endif %}
68
+ {{- '<|im_end|>' + '\n' }}
69
+ {%- elif message.role == 'user' %}
70
+ {{- '<|im_start|>user\n' }}
71
+ {%- if message['content'] is string %}
72
+ {{- message['content'] + '<|im_end|>\n' }}
73
+ {%- elif message['content'] is sequence %}
74
+ {%- for content in message['content'] %}
75
+ {%- if not loop.first %}
76
+ {{- '\n' }}
77
+ {%- endif %}
78
+ {%- if content['type'] == 'image_url' %}
79
+ {%- set media_url = content.get('image_url', {}).get('url', '') %}
80
+ {%- set url_lower = media_url.lower() %}
81
+ {%- set image_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".svg"] %}
82
+ {%- set video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"] %}
83
+ {%- set ns_check = namespace(is_video=False) %}
84
+ {%- for ext in video_extensions %}
85
+ {%- if url_lower.endswith(ext) %}
86
+ {%- set ns_check.is_video = True %}
87
+ {%- endif %}
88
+ {%- endfor %}
89
+ {%- if ns_check.is_video %}
90
+ {%- set video_id = 'video_%02d' % ns_vid.count %}
91
+ {%- set ns_vid.count = ns_vid.count + 1 %}
92
+ {{- '<|mime_start|>{"id": "' + video_id + '", "type": "video/mp4", "filename": "video.mp4"}<|mime_end|>\n' }}
93
+ {{- '<|video_aux_start|>다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. {"video_duration": "<|video_meta_duration|>"}<|video_aux_end|>\n' }}
94
+ {{- '<|video_start|><|VIDEO_PAD|><|video_end|>' }}
95
+ {%- else %}
96
+ {%- set image_id = 'image_%02d' % ns_img.count %}
97
+ {%- set ns_img.count = ns_img.count + 1 %}
98
+ {{- '<|mime_start|>{"id": "' + image_id + '", "type": "image/jpeg", "filename": "image.jpg"}<|mime_end|>\n' }}
99
+ {{- '<|discrete_image_start|><|DISCRETE_IMAGE_PAD|><|discrete_image_end|>\n' }}
100
+ {{- '<|image_start|><|IMAGE_PAD|><|image_end|>' }}
101
+ {%- endif %}
102
+ {%- elif content['type'] == 'input_audio' %}
103
+ {%- set audio_id = 'audio_%02d' % ns_aud.count %}
104
+ {%- set ns_aud.count = ns_aud.count + 1 %}
105
+ {%- set input_audio = content.get('input_audio', {}) %}
106
+ {{- '<|mime_start|>{"id": "' + audio_id + '", "type": "audio/mpeg", "filename": "user_query.wav"}<|mime_end|>\n' }}
107
+ {{- '<|audio_aux_start|>다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. {"audio_duration": "<|audio_meta_duration|>"}<|audio_aux_end|>\n'}}
108
+ {{- '<|discrete_audio_start|><|DISCRETE_AUDIO_PAD|><|discrete_audio_end|>\n'}}
109
+ {{- '<|audio_start|><|AUDIO_PAD|><|audio_end|>'}}
110
+ {%- elif content['type'] == 'text' %}
111
+ {{- content['text'] }}
112
+ {%- endif %}
113
+ {%- endfor %}
114
+ {{- '<|im_end|>\n'}}
115
+ {%- endif %}
116
+ {%- elif message.role == 'assistant' %}
117
+ {%- set reasoning_content = '' %}
118
+ {%- if message.get('reasoning_content') is string %}
119
+ {%- set reasoning_content = message.get('reasoning_content') %}
120
+ {%- else %}
121
+ {%- if '</think>' in content %}
122
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
123
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
124
+ {%- endif %}
125
+ {%- endif %}
126
+ {%- if loop.index0 > ns.last_user_index %}
127
+ {%- if loop.last or reasoning_content %}
128
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' }}
129
+ {%- else %}
130
+ {{- '<|im_start|>' + message.role + '\n' }}
131
+ {%- endif %}
132
+ {%- else %}
133
+ {{- '<|im_start|>' + message.role + '\n' }}
134
+ {%- endif %}
135
+ {{- content }}
136
+ {%- if message.get('tool_calls') %}
137
+ {%- for tool_call in message.get('tool_calls', []) %}
138
+ {%- if not loop.first or content %}
139
+ {{- '\n' }}
140
+ {%- endif %}
141
+ {%- if tool_call.function %}
142
+ {%- set tool_call = tool_call.function %}
143
+ {%- endif %}
144
+ {{- '<tool_call>' + tool_call.name + '\n' }}
145
+ {%- set _args = tool_call.arguments %}
146
+ {%- for k, v in _args.items() %}
147
+ {{- '<arg_key>' + k + '</arg_key>\n' }}
148
+ {{- '<arg_value>' + (v | tojson if v is not string else v) + '</arg_value>\n' }}
149
+ {%- endfor %}
150
+ {{- '</tool_call>' }}
151
+ {%- endfor %}
152
+ {%- endif %}
153
+ {{- '<|im_end|>\n' }}
154
+ {%- elif message.role == 'tool' %}
155
+ {%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
156
+ {{- '<|im_start|>tool' }}
157
+ {%- endif %}
158
+ {{- '\n<tool_response>' + message.get('name', '') + '\n' }}
159
+ {%- if message['content'] is string %}
160
+ {{- content }}
161
+ {%- endif %}
162
+ {{- '\n</tool_response>' }}
163
+ {%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
164
+ {{- '<|im_end|>\n' }}
165
+ {%- endif %}
166
+ {%- endif %}
167
+ {%- endfor %}
168
+ {%- if add_generation_prompt %}
169
+ {{- '<|im_start|>assistant\n<think>\n' }}
170
+ {%- if skip_reasoning is defined and skip_reasoning is true %}
171
+ {{- '\n</think>\n\n' }}
172
+ {%- endif %}
173
+ {%- endif %}
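For reference, a minimal sketch of rendering this chat template offline with `transformers`. It assumes the tokenizer in this repository bundles the template and that you are willing to pass `trust_remote_code=True`; the repo id is the one used in the README, and `skip_reasoning` is the template variable exercised by the README examples.

```python
from transformers import AutoTokenizer

# Load the tokenizer (and its bundled chat template) from this repository.
tok = AutoTokenizer.from_pretrained(
    "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", trust_remote_code=True
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# skip_reasoning=True makes the template close the <think> block immediately,
# matching the chat_template_kwargs used in the README examples.
prompt = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    skip_reasoning=True,
)
print(prompt)
```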
config.json ADDED
@@ -0,0 +1,320 @@
1
+ {
2
+ "anyres": false,
3
+ "architectures": [
4
+ "HCXVisionV2ForCausalLM"
5
+ ],
6
+ "audio_config": {
7
+ "activation_dropout": 0.0,
8
+ "activation_function": "gelu",
9
+ "add_cross_attention": false,
10
+ "architectures": [
11
+ "Qwen2AudioEncoder"
12
+ ],
13
+ "attention_dropout": 0.0,
14
+ "bad_words_ids": null,
15
+ "begin_suppress_tokens": null,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "d_model": 1280,
20
+ "decoder_start_token_id": null,
21
+ "diversity_penalty": 0.0,
22
+ "do_sample": false,
23
+ "dropout": 0.0,
24
+ "early_stopping": false,
25
+ "encoder_attention_heads": 20,
26
+ "encoder_ffn_dim": 5120,
27
+ "encoder_layerdrop": 0.0,
28
+ "encoder_layers": 32,
29
+ "encoder_no_repeat_ngram_size": 0,
30
+ "eos_token_id": null,
31
+ "exponential_decay_length_penalty": null,
32
+ "finetuning_task": null,
33
+ "forced_bos_token_id": null,
34
+ "forced_eos_token_id": null,
35
+ "id2label": {
36
+ "0": "LABEL_0",
37
+ "1": "LABEL_1"
38
+ },
39
+ "init_std": 0.02,
40
+ "initializer_range": 0.02,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "length_penalty": 1.0,
48
+ "max_length": 20,
49
+ "max_source_positions": 1500,
50
+ "min_length": 0,
51
+ "model_type": "qwen2_audio_encoder",
52
+ "no_repeat_ngram_size": 0,
53
+ "num_beam_groups": 1,
54
+ "num_beams": 1,
55
+ "num_hidden_layers": 32,
56
+ "num_mel_bins": 128,
57
+ "num_return_sequences": 1,
58
+ "output_attentions": false,
59
+ "output_hidden_states": false,
60
+ "output_scores": false,
61
+ "pad_token_id": null,
62
+ "prefix": null,
63
+ "problem_type": null,
64
+ "pruned_heads": {},
65
+ "remove_invalid_values": false,
66
+ "repetition_penalty": 1.0,
67
+ "return_dict": true,
68
+ "return_dict_in_generate": false,
69
+ "scale_embedding": false,
70
+ "sep_token_id": null,
71
+ "suppress_tokens": null,
72
+ "task_specific_params": null,
73
+ "temperature": 1.0,
74
+ "tf_legacy_loss": false,
75
+ "tie_encoder_decoder": false,
76
+ "tie_word_embeddings": true,
77
+ "tokenizer_class": null,
78
+ "top_k": 50,
79
+ "top_p": 1.0,
80
+ "torch_dtype": "float32",
81
+ "torchscript": false,
82
+ "typical_p": 1.0,
83
+ "use_bfloat16": false
84
+ },
85
+ "audio_model_name_or_path": null,
86
+ "audio_projector_type": "mlp",
87
+ "audio_start_id": 128071,
88
+ "audio_token_id": 128071,
89
+ "auto_map": {
90
+ "AutoConfig": "configuration_vlm.HCXVisionConfig",
91
+ "AutoModelForCausalLM": "modeling_vlm.HCXVisionForCausalLM",
92
+ "AutoModelForSequenceClassification": "modeling_vlm.HCXVisionForSequenceClassification"
93
+ },
94
+ "discrete_audio_config": {
95
+ "model_name_or_path": null,
96
+ "model_type": "cosyvoice2",
97
+ "torch_dtype": "float32"
98
+ },
99
+ "discrete_audio_model_name_or_path": null,
100
+ "discrete_audio_start_id": 128074,
101
+ "discrete_audio_token_id": 128074,
102
+ "discrete_audio_unit_0_id": 128606,
103
+ "discrete_image_start_id": 128069,
104
+ "discrete_image_token_id": 128069,
105
+ "discrete_image_unit_0_id": 135168,
106
+ "discrete_vision_config": {
107
+ "model_name_or_path": null,
108
+ "model_type": "ta_tok",
109
+ "torch_dtype": "float32"
110
+ },
111
+ "discrete_vision_model_name_or_path": null,
112
+ "end_token_id": 128001,
113
+ "eos_token_id": 128001,
114
+ "freeze_audio_projector": true,
115
+ "freeze_before_sampler": false,
116
+ "freeze_decoder": false,
117
+ "freeze_encoder": true,
118
+ "freeze_mm_projector": false,
119
+ "freeze_video_audio_compressor": false,
120
+ "hidden_size": 4096,
121
+ "ignore_index": -100,
122
+ "image_token_id": 128062,
123
+ "img_start_id": 128062,
124
+ "is_safetensor_save": true,
125
+ "max_num_grids": -1,
126
+ "mm_projector_type": "linear",
127
+ "model_type": "vlm",
128
+ "num_queries_vis_abstractor": -1,
129
+ "possible_resolutions": [],
130
+ "proj_pos_emb": true,
131
+ "proj_prenorm": false,
132
+ "q_former_model_name_or_path": null,
133
+ "text_config": {
134
+ "add_cross_attention": false,
135
+ "architectures": [
136
+ "LlamaForCausalLM"
137
+ ],
138
+ "attention_bias": false,
139
+ "attention_dropout": 0.0,
140
+ "bad_words_ids": null,
141
+ "begin_suppress_tokens": null,
142
+ "bos_token_id": 128000,
143
+ "chunk_size_feed_forward": 0,
144
+ "cross_attention_hidden_size": null,
145
+ "decoder_start_token_id": null,
146
+ "diversity_penalty": 0.0,
147
+ "do_sample": false,
148
+ "early_stopping": false,
149
+ "encoder_no_repeat_ngram_size": 0,
150
+ "eos_token_id": 128001,
151
+ "exponential_decay_length_penalty": null,
152
+ "finetuning_task": null,
153
+ "forced_bos_token_id": null,
154
+ "forced_eos_token_id": null,
155
+ "head_dim": 128,
156
+ "hidden_act": "silu",
157
+ "hidden_size": 4096,
158
+ "id2label": {
159
+ "0": "LABEL_0",
160
+ "1": "LABEL_1"
161
+ },
162
+ "initializer_range": 0.02,
163
+ "intermediate_size": 12288,
164
+ "is_decoder": false,
165
+ "is_encoder_decoder": false,
166
+ "label2id": {
167
+ "LABEL_0": 0,
168
+ "LABEL_1": 1
169
+ },
170
+ "length_penalty": 1.0,
171
+ "logits_scaling": 1.0,
172
+ "max_length": 20,
173
+ "max_position_embeddings": 8192,
174
+ "min_length": 0,
175
+ "mlp_bias": false,
176
+ "model_type": "llama",
177
+ "no_repeat_ngram_size": 0,
178
+ "num_attention_heads": 32,
179
+ "num_beam_groups": 1,
180
+ "num_beams": 1,
181
+ "num_hidden_layers": 36,
182
+ "num_key_value_heads": 8,
183
+ "num_return_sequences": 1,
184
+ "output_attentions": false,
185
+ "output_hidden_states": false,
186
+ "output_scores": false,
187
+ "pad_token_id": null,
188
+ "prefix": null,
189
+ "pretraining_tp": 1,
190
+ "problem_type": null,
191
+ "pruned_heads": {},
192
+ "remove_invalid_values": false,
193
+ "repetition_penalty": 1.0,
194
+ "return_dict": true,
195
+ "return_dict_in_generate": false,
196
+ "rms_norm_eps": 1e-06,
197
+ "rope_scaling": null,
198
+ "rope_theta": 5000000,
199
+ "sep_token_id": null,
200
+ "suppress_tokens": null,
201
+ "task_specific_params": null,
202
+ "temperature": 1.0,
203
+ "tf_legacy_loss": false,
204
+ "tie_encoder_decoder": false,
205
+ "tie_word_embeddings": false,
206
+ "tokenizer_class": null,
207
+ "top_k": 50,
208
+ "top_p": 1.0,
209
+ "torch_dtype": "float32",
210
+ "torchscript": false,
211
+ "typical_p": 1.0,
212
+ "use_bfloat16": false,
213
+ "use_cache": true,
214
+ "vocab_size": 200704
215
+ },
216
+ "text_model_name_or_path": null,
217
+ "torch_dtype": "float32",
218
+ "transformers_version": "4.52.4",
219
+ "unpad": false,
220
+ "use_1x1_grid": false,
221
+ "use_nth_layer": -2,
222
+ "video_audio_compressor_type": "mambamia",
223
+ "video_audio_start_id": 128070,
224
+ "video_audio_token_id": 128070,
225
+ "video_first_last_frames_slows": null,
226
+ "video_max_num_frames": 120,
227
+ "video_num_queries_fast": null,
228
+ "video_num_queries_slow": null,
229
+ "video_start_id": 128063,
230
+ "video_token_id": 128063,
231
+ "vision_config": {
232
+ "add_cross_attention": false,
233
+ "anyres": false,
234
+ "architectures": [
235
+ "Qwen2_5_VisionTransformerPretrainedModel"
236
+ ],
237
+ "bad_words_ids": null,
238
+ "begin_suppress_tokens": null,
239
+ "bos_token_id": null,
240
+ "chunk_size_feed_forward": 0,
241
+ "cross_attention_hidden_size": null,
242
+ "decoder_start_token_id": null,
243
+ "depth": 32,
244
+ "diversity_penalty": 0.0,
245
+ "do_sample": false,
246
+ "early_stopping": false,
247
+ "encoder_no_repeat_ngram_size": 0,
248
+ "eos_token_id": null,
249
+ "exponential_decay_length_penalty": null,
250
+ "finetuning_task": null,
251
+ "forced_bos_token_id": null,
252
+ "forced_eos_token_id": null,
253
+ "fullatt_block_indexes": [
254
+ 7,
255
+ 15,
256
+ 23,
257
+ 31
258
+ ],
259
+ "hidden_act": "silu",
260
+ "hidden_size": 1280,
261
+ "id2label": {
262
+ "0": "LABEL_0",
263
+ "1": "LABEL_1"
264
+ },
265
+ "in_channels": 3,
266
+ "in_chans": 3,
267
+ "initializer_range": 0.02,
268
+ "intermediate_size": 3456,
269
+ "is_decoder": false,
270
+ "is_encoder_decoder": false,
271
+ "label2id": {
272
+ "LABEL_0": 0,
273
+ "LABEL_1": 1
274
+ },
275
+ "length_penalty": 1.0,
276
+ "max_length": 20,
277
+ "max_num_grids": -1,
278
+ "min_length": 0,
279
+ "model_type": "qwen2_5_vl",
280
+ "no_repeat_ngram_size": 0,
281
+ "num_beam_groups": 1,
282
+ "num_beams": 1,
283
+ "num_heads": 16,
284
+ "num_return_sequences": 1,
285
+ "out_hidden_size": 5120,
286
+ "output_attentions": false,
287
+ "output_hidden_states": false,
288
+ "output_scores": false,
289
+ "pad_token_id": null,
290
+ "patch_size": 14,
291
+ "prefix": null,
292
+ "problem_type": null,
293
+ "pruned_heads": {},
294
+ "remove_invalid_values": false,
295
+ "repetition_penalty": 1.0,
296
+ "return_dict": true,
297
+ "return_dict_in_generate": false,
298
+ "sep_token_id": null,
299
+ "spatial_merge_size": 2,
300
+ "spatial_patch_size": 14,
301
+ "suppress_tokens": null,
302
+ "task_specific_params": null,
303
+ "temperature": 1.0,
304
+ "temporal_patch_size": 2,
305
+ "tf_legacy_loss": false,
306
+ "tie_encoder_decoder": false,
307
+ "tie_word_embeddings": true,
308
+ "tokenizer_class": null,
309
+ "tokens_per_second": 2,
310
+ "top_k": 50,
311
+ "top_p": 1.0,
312
+ "torch_dtype": "float32",
313
+ "torchscript": false,
314
+ "typical_p": 1.0,
315
+ "use_bfloat16": false,
316
+ "window_size": 112
317
+ },
318
+ "vision_input_chunk_size": null,
319
+ "vision_model_name_or_path": null
320
+ }
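The `auto_map` entries above route loading through the custom classes shipped in this repository (`configuration_vlm.HCXVisionConfig`, `modeling_vlm.HCXVisionForCausalLM`), so the config is intended to be loaded with `trust_remote_code=True`. A minimal sketch, using the repo id from the README:

```python
from transformers import AutoConfig

# auto_map in config.json resolves this to configuration_vlm.HCXVisionConfig.
config = AutoConfig.from_pretrained(
    "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", trust_remote_code=True
)

print(type(config).__name__)  # the custom config class from configuration_vlm.py
print(config.model_type)      # "vlm"
print(config.hidden_size)     # 4096 (top-level value above)
```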
configuration_hyperclovax.py ADDED
@@ -0,0 +1,228 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+
24
+ # from transformers.modeling_rope_utils import rope_config_validation
25
+ # from transformers import PretrainedConfig, rope_config_validation
26
+
27
+
28
+ class HyperCLOVAXConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
31
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
+ defaults will yield a similar configuration to that of the LLaMA-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`LlamaModel`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
62
+ Llama 2 up to 4096, CodeLlama up to 16384.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ pad_token_id (`int`, *optional*):
71
+ Padding token id.
72
+ bos_token_id (`int`, *optional*, defaults to 1):
73
+ Beginning of stream token id.
74
+ eos_token_id (`int`, *optional*, defaults to 2):
75
+ End of stream token id.
76
+ pretraining_tp (`int`, *optional*, defaults to 1):
77
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
78
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
79
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
80
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
81
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
+ Whether to tie weight embeddings
83
+ rope_theta (`float`, *optional*, defaults to 10000.0):
84
+ The base period of the RoPE embeddings.
85
+ rope_scaling (`Dict`, *optional*):
86
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
87
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
88
+ accordingly.
89
+ Expected contents:
90
+ `rope_type` (`str`):
91
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
92
+ 'llama3'], with 'default' being the original RoPE implementation.
93
+ `factor` (`float`, *optional*):
94
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
95
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
96
+ original maximum pre-trained length.
97
+ `original_max_position_embeddings` (`int`, *optional*):
98
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
99
+ pretraining.
100
+ `attention_factor` (`float`, *optional*):
101
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
102
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
103
+ `factor` field to infer the suggested value.
104
+ `beta_fast` (`float`, *optional*):
105
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
106
+ ramp function. If unspecified, it defaults to 32.
107
+ `beta_slow` (`float`, *optional*):
108
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
109
+ ramp function. If unspecified, it defaults to 1.
110
+ `short_factor` (`List[float]`, *optional*):
111
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
112
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
113
+ size divided by the number of attention heads divided by 2
114
+ `long_factor` (`List[float]`, *optional*):
115
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
116
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
117
+ size divided by the number of attention heads divided by 2
118
+ `low_freq_factor` (`float`, *optional*):
119
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
120
+ `high_freq_factor` (`float`, *optional*):
121
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
122
+ attention_bias (`bool`, *optional*, defaults to `False`):
123
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
124
+ attention_dropout (`float`, *optional*, defaults to 0.0):
125
+ The dropout ratio for the attention probabilities.
126
+ mlp_bias (`bool`, *optional*, defaults to `False`):
127
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
128
+ head_dim (`int`, *optional*):
129
+ The attention head dimension. If None, it will default to hidden_size // num_heads
130
+
131
+ ```python
132
+ >>> from transformers import LlamaModel, LlamaConfig
133
+
134
+ >>> # Initializing a LLaMA llama-7b style configuration
135
+ >>> configuration = LlamaConfig()
136
+
137
+ >>> # Initializing a model from the llama-7b style configuration
138
+ >>> model = LlamaModel(configuration)
139
+
140
+ >>> # Accessing the model configuration
141
+ >>> configuration = model.config
142
+ ```"""
143
+
144
+ model_type = "hyperclovax"
145
+ keys_to_ignore_at_inference = ["past_key_values"]
146
+
147
+ def __init__(
148
+ self,
149
+ vocab_size=32000,
150
+ hidden_size=4096,
151
+ intermediate_size=11008,
152
+ num_hidden_layers=32,
153
+ num_attention_heads=32,
154
+ num_key_value_heads=None,
155
+ hidden_act="silu",
156
+ max_position_embeddings=2048,
157
+ initializer_range=0.02,
158
+ rms_norm_eps=1e-6,
159
+ use_cache=True,
160
+ pad_token_id=None,
161
+ bos_token_id=1,
162
+ eos_token_id=2,
163
+ pretraining_tp=1,
164
+ tie_word_embeddings=False,
165
+ rope_theta=10000.0,
166
+ rope_scaling=None,
167
+ attention_bias=False,
168
+ attention_dropout=0.0,
169
+ mlp_bias=False,
170
+ head_dim=None,
171
+ embedding_multiplier=1.0, # mup
172
+ logits_scaling=1.0, # mup
173
+ attention_multiplier=1.0, # mup
174
+ residual_multiplier=1.0, # mup
175
+ use_post_norm=False, # post-norm
176
+ auto_map={
177
+ "AutoConfig": "configuration_hyperclovax.HyperCLOVAXConfig",
178
+ "AutoModel": "modeling_hyperclovax.HyperCLOVAXModel",
179
+ "AutoModelForCausalLM": "modeling_hyperclovax.HyperCLOVAXForCausalLM",
180
+ },
181
+ **kwargs,
182
+ ):
183
+ self.vocab_size = vocab_size
184
+ self.max_position_embeddings = max_position_embeddings
185
+ self.hidden_size = hidden_size
186
+ self.intermediate_size = intermediate_size
187
+ self.num_hidden_layers = num_hidden_layers
188
+ self.num_attention_heads = num_attention_heads
189
+
190
+ # for backward compatibility
191
+ if num_key_value_heads is None:
192
+ num_key_value_heads = num_attention_heads
193
+
194
+ self.num_key_value_heads = num_key_value_heads
195
+ self.hidden_act = hidden_act
196
+ self.initializer_range = initializer_range
197
+ self.rms_norm_eps = rms_norm_eps
198
+ self.pretraining_tp = pretraining_tp
199
+ self.use_cache = use_cache
200
+ self.rope_theta = rope_theta
201
+ self.rope_scaling = rope_scaling
202
+ self.attention_bias = attention_bias
203
+ self.attention_dropout = attention_dropout
204
+ self.mlp_bias = mlp_bias
205
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
206
+ # Validate the correctness of rotary position embeddings parameters
207
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
208
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
209
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
210
+ # rope_config_validation(self)
211
+
212
+ # mup
213
+ self.embedding_multiplier = embedding_multiplier
214
+ self.logits_scaling = logits_scaling
215
+ self.attention_multiplier = attention_multiplier
216
+ self.residual_multiplier = residual_multiplier
217
+
218
+ # post-norm (dual-norm)
219
+ self.use_post_norm = use_post_norm
220
+
221
+ super().__init__(
222
+ pad_token_id=pad_token_id,
223
+ bos_token_id=bos_token_id,
224
+ eos_token_id=eos_token_id,
225
+ tie_word_embeddings=tie_word_embeddings,
226
+ auto_map=auto_map,
227
+ **kwargs,
228
+ )
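A minimal construction sketch for the config class above, assuming this file is importable from the working directory; the concrete values are illustrative defaults, not the released checkpoint's settings (those live in config.json):

from configuration_hyperclovax import HyperCLOVAXConfig

# Illustrative values only; the released checkpoint stores its real settings in config.json.
cfg = HyperCLOVAXConfig(
    vocab_size=32000,
    hidden_size=4096,
    num_attention_heads=32,
    num_hidden_layers=32,
    embedding_multiplier=1.0,  # mup scaling knobs exposed by this config
    logits_scaling=1.0,
    attention_multiplier=1.0,
    residual_multiplier=1.0,
    use_post_norm=False,       # dual-norm (post-norm) toggle
)
print(cfg.model_type)  # "hyperclovax"
print(cfg.head_dim)    # hidden_size // num_attention_heads = 128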
configuration_vlm.py ADDED
@@ -0,0 +1,169 @@
1
+ import transformers
2
+ from transformers import AutoConfig, PretrainedConfig
3
+
4
+
5
+ class HCXVisionConfig(PretrainedConfig):
6
+ model_type = "vlm"
7
+ keys_to_ignore_at_inference = ["past_key_values"]
8
+
9
+ def __init__(
10
+ self,
11
+ text_config=None,
12
+ vision_config=None,
13
+ discrete_vision_config=None,
14
+ audio_config=None,
15
+ discrete_audio_config=None,
16
+ text_model_name_or_path=None,
17
+ vision_model_name_or_path=None,
18
+ discrete_vision_model_name_or_path=None,
19
+ audio_model_name_or_path=None,
20
+ discrete_audio_model_name_or_path=None,
21
+ q_former_model_name_or_path=None,
22
+ mm_projector_type="mlp",
23
+ audio_projector_type="mlp",
24
+ video_audio_compressor_type=None,
25
+ use_nth_layer=-2,
26
+ img_start_id=128062, # <|IMAGE_PAD|> # Manually adjusted value from previous checkpoint
27
+ discrete_image_start_id=128250, # <|DISCRETE_AUDIO_PAD|>
28
+ discrete_image_unit_0_id=135166, # <|vision00000|>
29
+ video_start_id=128063, # <|VIDEO_PAD|>
30
+ video_audio_start_id=None, # <|VIDEO_AUDIO_PAD|> - will be set dynamically
31
+ audio_start_id=128253, # <|AUDIO_PAD|>
32
+ discrete_audio_start_id=128250, # <|DISCRETE_AUDIO_PAD|>
33
+ discrete_audio_unit_0_id=128604, # <|audio0000|>
34
+ freeze_encoder=False,
35
+ freeze_decoder=False,
36
+ freeze_mm_projector=False,
37
+ freeze_audio_projector=False,
38
+ freeze_video_audio_compressor=False,
39
+ anyres=False,
40
+ unpad=False,
41
+ max_num_grids=-1,
42
+ num_queries_vis_abstractor=-1,
43
+ video_num_queries_fast=None,
44
+ video_num_queries_slow=None,
45
+ video_first_last_frames_slows=None,
46
+ video_max_num_frames=None,
47
+ ignore_index=-100,
48
+ proj_pos_emb=True,
49
+ proj_prenorm=False,
50
+ use_1x1_grid=False,
51
+ possible_resolutions=[],
52
+ **kwargs,
53
+ ):
54
+ from transformers import CONFIG_MAPPING
55
+
56
+ if kwargs.get("language_config", None) is not None: # for bc
57
+ text_config = CONFIG_MAPPING[kwargs["language_config"]["model_type"]](**kwargs["language_config"])
58
+ elif text_config is None and text_model_name_or_path is not None:
59
+ text_config = AutoConfig.from_pretrained(text_model_name_or_path, trust_remote_code=True)
60
+ if vision_config is None and vision_model_name_or_path is not None:
61
+ vision_config = AutoConfig.from_pretrained(vision_model_name_or_path, trust_remote_code=True)
62
+ if discrete_vision_config is None and discrete_vision_model_name_or_path is not None:
63
+ discrete_vision_config = {
64
+ "model_type": "ta_tok",
65
+ "model_name_or_path": discrete_vision_model_name_or_path,
66
+ }
67
+ if audio_config is None and audio_model_name_or_path is not None:
68
+ audio_config = AutoConfig.from_pretrained(audio_model_name_or_path)
69
+ if discrete_audio_config is None and discrete_audio_model_name_or_path is not None:
70
+ discrete_audio_config = {
71
+ "model_type": "cosyvoice2",
72
+ "model_name_or_path": discrete_audio_model_name_or_path,
73
+ }
74
+
75
+ if isinstance(text_config, dict):
76
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
77
+
78
+ if isinstance(vision_config, dict):
79
+ if vision_config["model_type"] == "qwen2_5_vl":
80
+ vision_config["model_type"] = "qwen2_5_vl_visual"
81
+ assert transformers.__version__ >= "4.52.4", "please upgrade transformers to 4.52.4 or higher"
82
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
83
+
84
+ if isinstance(audio_config, dict):
85
+ audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
86
+
87
+ self.text_config = text_config
88
+ self.vision_config = vision_config
89
+ self.discrete_vision_config = discrete_vision_config
90
+ self.audio_config = audio_config
91
+ self.discrete_audio_config = discrete_audio_config
92
+
93
+ if text_config is not None:
94
+ # DeepSpeed ZeRO-3 sizes its memory partitions automatically from the config's hidden_size.
95
+ self.hidden_size = text_config.hidden_size if hasattr(text_config, "hidden_size") else text_config.n_embd
96
+ # add VLM configs
97
+ self.text_model_name_or_path = text_model_name_or_path
98
+ self.vision_model_name_or_path = vision_model_name_or_path
99
+ self.discrete_vision_model_name_or_path = discrete_vision_model_name_or_path
100
+ self.audio_model_name_or_path = audio_model_name_or_path
101
+ self.discrete_audio_model_name_or_path = discrete_audio_model_name_or_path
102
+ self.q_former_model_name_or_path = q_former_model_name_or_path
103
+ self.mm_projector_type = mm_projector_type
104
+ self.audio_projector_type = audio_projector_type
105
+ self.video_audio_compressor_type = video_audio_compressor_type
106
+ self.use_nth_layer = use_nth_layer
107
+ self.freeze_encoder = freeze_encoder
108
+ self.freeze_decoder = freeze_decoder
109
+ self.freeze_mm_projector = freeze_mm_projector
110
+ self.freeze_audio_projector = freeze_audio_projector
111
+ self.freeze_video_audio_compressor = freeze_video_audio_compressor
112
+ self.anyres = anyres
113
+ self.unpad = unpad
114
+ self.max_num_grids = max_num_grids
115
+ self.num_queries_vis_abstractor = num_queries_vis_abstractor
116
+ self.video_num_queries_fast = video_num_queries_fast
117
+ self.video_num_queries_slow = video_num_queries_slow
118
+ self.video_first_last_frames_slows = video_first_last_frames_slows
119
+ self.video_max_num_frames = video_max_num_frames
120
+
121
+ self.img_start_id = img_start_id
122
+ self.image_token_id = img_start_id
123
+
124
+ self.discrete_image_start_id = discrete_image_start_id
125
+ self.discrete_image_token_id = discrete_image_start_id
126
+ self.discrete_image_unit_0_id = discrete_image_unit_0_id
127
+
128
+ self.video_start_id = video_start_id
129
+ self.video_token_id = video_start_id
130
+
131
+ self.video_audio_start_id = video_audio_start_id
132
+ self.video_audio_token_id = video_audio_start_id
133
+
134
+ self.audio_start_id = audio_start_id
135
+ self.audio_token_id = audio_start_id
136
+
137
+ self.discrete_audio_start_id = discrete_audio_start_id
138
+ self.discrete_audio_token_id = discrete_audio_start_id
139
+ self.discrete_audio_unit_0_id = discrete_audio_unit_0_id
140
+
141
+ self.ignore_index = ignore_index
142
+ self.proj_pos_emb = proj_pos_emb
143
+ self.proj_prenorm = proj_prenorm
144
+ self.use_1x1_grid = use_1x1_grid
145
+ self.possible_resolutions = possible_resolutions
146
+ super().__init__(**kwargs)
147
+ if self.text_config is not None: # needed for HCXVisionForSequenceClassification
148
+ self.pad_token_id = self.text_config.pad_token_id
149
+
150
+
151
+ AutoConfig.register("vlm", HCXVisionConfig)
152
+ try:
153
+ from .configuration_hyperclovax import HyperCLOVAXConfig
154
+
155
+ AutoConfig.register("hyperclovax", HyperCLOVAXConfig)
156
+ except Exception:  # configuration_hyperclovax may be absent or already registered
157
+ pass
158
+ try:
159
+ from transformers import CONFIG_MAPPING, MODEL_MAPPING
160
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
161
+ Qwen2_5_VisionTransformerPretrainedModel,
162
+ Qwen2_5_VLPatchMerger,
163
+ Qwen2_5_VLVisionConfig,
164
+ )
165
+
166
+ MODEL_MAPPING.register(Qwen2_5_VLVisionConfig, Qwen2_5_VisionTransformerPretrainedModel)
167
+ CONFIG_MAPPING.register("qwen2_5_vl_visual", Qwen2_5_VLVisionConfig)
168
+ except Exception:  # older transformers versions may not expose the Qwen2.5-VL vision classes
169
+ pass
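A rough sketch of how the composite config above can be assembled from a text-backbone config; the values are placeholders, and in the released checkpoint the nested configs come straight from config.json rather than being built by hand:

from configuration_hyperclovax import HyperCLOVAXConfig
from configuration_vlm import HCXVisionConfig

text_cfg = HyperCLOVAXConfig(vocab_size=32000, hidden_size=4096)

vlm_cfg = HCXVisionConfig(
    text_config=text_cfg,
    mm_projector_type="mlp",
    audio_projector_type="mlp",
)
print(vlm_cfg.hidden_size)     # mirrored from text_config.hidden_size (used by DeepSpeed ZeRO-3)
print(vlm_cfg.image_token_id)  # 128062, alias of img_start_id (<|IMAGE_PAD|>)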
cosyvoice.py ADDED
@@ -0,0 +1,516 @@
1
+ # Copyright (c) (Mddct: Dinghao Zhou)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple
17
+
18
+ import librosa
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+ from torch import nn
23
+
24
+ DEFAULT_SAMPLE_RATE = 16000  # NOTE: fixed at 16 kHz for the time being.
25
+ MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600  # 0.1 s; guarantees code_len >= 1 even after CosyVoice's two conv downsampling steps
26
+
27
+
28
+ @dataclass
29
+ class ModelConfig:
30
+ n_mels: int = 128
31
+ n_audio_ctx: int = 1500
32
+ n_audio_state: int = 1280
33
+ n_audio_head: int = 20
34
+ n_audio_layer: int = 6
35
+ n_codebook_size: int = 3**8
36
+
37
+ use_sdpa: bool = True
38
+
39
+
40
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scaling=None):
41
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
42
+ t = torch.arange(end, device=freqs.device) # type: ignore
43
+ if scaling is not None:
44
+ t = t * scaling
45
+ freqs = torch.outer(t, freqs).float() # type: ignore
46
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
47
+
48
+ return torch.cat((freqs_cis, freqs_cis), dim=-1)
49
+
50
+
51
+ def apply_rotary_emb(
52
+ xq: torch.Tensor,
53
+ xk: torch.Tensor,
54
+ freqs_cis: torch.Tensor,
55
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
56
+ real = torch.view_as_real(freqs_cis)
57
+ cos, sin = real[:, :, 0], real[:, :, 1]
58
+ cos = cos.unsqueeze(0).unsqueeze(2)
59
+ sin = sin.unsqueeze(0).unsqueeze(2)
60
+
61
+ D = xq.shape[-1]
62
+ half_l, half_r = xq[:, :, :, : D // 2], xq[:, :, :, D // 2 :]
63
+ xq_r = torch.cat((-half_r, half_l), dim=-1)
64
+
65
+ D = xk.shape[-1]
66
+
67
+ half_l, half_r = xk[:, :, :, : D // 2], xk[:, :, :, D // 2 :]
68
+ xk_r = torch.cat((-half_r, half_l), dim=-1)
69
+
70
+ return xq * cos + xq_r * sin, xk * cos + xk_r * sin
71
+
72
+
73
+ class LayerNorm(nn.LayerNorm):
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ return super().forward(x.float()).type(x.dtype)
76
+
77
+
78
+ class Linear(nn.Linear):
79
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
80
+ return F.linear(
81
+ x,
82
+ self.weight.to(x.dtype),
83
+ None if self.bias is None else self.bias.to(x.dtype),
84
+ )
85
+
86
+
87
+ class Conv1d(nn.Conv1d):
88
+ def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
89
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
90
+
91
+
92
+ class MultiHeadAttention(nn.Module):
93
+ def __init__(self, n_state: int, n_head: int, use_sdpa: bool = True):
94
+ super().__init__()
95
+ self.n_head = n_head
96
+ self.query = Linear(n_state, n_state)
97
+ self.key = Linear(n_state, n_state, bias=False)
98
+ self.value = Linear(n_state, n_state)
99
+ self.out = Linear(n_state, n_state)
100
+
101
+ self.use_sdpa = use_sdpa
102
+
103
+ def forward(
104
+ self,
105
+ x: torch.Tensor,
106
+ mask: Optional[torch.Tensor] = None,
107
+ ):
108
+ q = self.query(x)
109
+ k = self.key(x)
110
+ v = self.value(x)
111
+
112
+ wv, qk = self.qkv_attention(q, k, v, mask)
113
+ return self.out(wv), qk
114
+
115
+ def qkv_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None):
116
+ _, _, D = q.shape
117
+ scale = (D // self.n_head) ** -0.25
118
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
119
+ k = k.view(*k.shape[:2], self.n_head, -1)
120
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
121
+
122
+ if not self.use_sdpa:
123
+ k = k.permute(0, 2, 3, 1) * scale
124
+ qk = q @ k # (B, n_head, T, T)
125
+ if mask is not None:
126
+ qk = qk + mask
127
+ qk = qk.float()
128
+ w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
129
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
130
+ else:
131
+ k = k.permute(0, 2, 1, 3) * scale
132
+ assert mask is not None
133
+ output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0)
134
+ output = output.transpose(1, 2).contiguous().view(q.size(0), -1, D) # (batch, time1, d_model)
135
+ return output, None
136
+
137
+
138
+ class FSQCodebook(torch.nn.Module):
139
+ def __init__(self, dim: int, level: int = 3):
140
+ super().__init__()
141
+ self.project_down = torch.nn.Linear(dim, 8)
142
+ self.level = level
143
+ self.embed = None
144
+
145
+ @torch.inference_mode()
146
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
147
+ x = rearrange(x, "... d -> (...) d")
148
+ return x
149
+
150
+ @torch.inference_mode()
151
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
152
+ x_shape = x.shape
153
+ # pre-process
154
+ x = self.preprocess(x)
155
+ # quantize
156
+ h = self.project_down(x).float()
157
+ h = h.tanh()
158
+ h = h * 0.9990000128746033
159
+ h = h.round() + 1
160
+ # h = ((self.level - 1) * h).round() # range [-k, k]
161
+ powers = torch.pow(self.level, torch.arange(2**self.level, device=x.device, dtype=h.dtype))
162
+ mu = torch.sum(h * powers.unsqueeze(0), dim=-1)
163
+ ind = mu.reshape(x_shape[0], x_shape[1]).int()
164
+ return ind
165
+
166
+ @torch.inference_mode()
167
+ def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
168
+ raise NotImplementedError("There is no official up project component provided")
169
+
170
+
171
+ class FSQVectorQuantization(torch.nn.Module):
172
+ """Vector quantization implementation (inference-only).
173
+ Args:
174
+ dim (int): Dimension
175
+ codebook_size (int): Codebook size
176
+ """
177
+
178
+ def __init__(
179
+ self,
180
+ dim: int,
181
+ codebook_size: int,
182
+ ):
183
+ super().__init__()
184
+ assert 3**8 == codebook_size
185
+ self._codebook = FSQCodebook(dim=dim, level=3)
186
+ self.codebook_size = codebook_size
187
+
188
+ @property
189
+ def codebook(self):
190
+ return self._codebook.embed
191
+
192
+ @torch.inference_mode()
193
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
194
+ return self._codebook.encode(x)
195
+
196
+ @torch.inference_mode()
197
+ def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
198
+ quantize = self._codebook.decode(embed_ind)
199
+ quantize = rearrange(quantize, "b n d -> b d n")
200
+ return quantize
201
+
202
+
203
+ class FSMNMultiHeadAttention(MultiHeadAttention):
204
+ def __init__(
205
+ self,
206
+ n_state: int,
207
+ n_head: int,
208
+ kernel_size: int = 31,
209
+ use_sdpa: bool = True,
210
+ ):
211
+ super().__init__(n_state, n_head)
212
+
213
+ self.fsmn_block = torch.nn.Conv1d(
214
+ n_state, n_state, kernel_size, stride=1, padding=0, groups=n_state, bias=False
215
+ )
216
+ self.left_padding = (kernel_size - 1) // 2
217
+ self.right_padding = kernel_size - 1 - self.left_padding
218
+ self.pad_fn = torch.nn.ConstantPad1d((self.left_padding, self.right_padding), 0.0)
219
+
220
+ self.use_sdpa = use_sdpa
221
+
222
+ def forward_fsmn(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None):
223
+ b, t, _, _ = inputs.size()
224
+ inputs = inputs.view(b, t, -1)
225
+ if mask is not None and mask.size(2) > 0: # time2 > 0
226
+ inputs = inputs * mask
227
+ x = inputs.transpose(1, 2)
228
+ x = self.pad_fn(x)
229
+ x = self.fsmn_block(x)
230
+ x = x.transpose(1, 2)
231
+ x += inputs
232
+ return x * mask
233
+
234
+ def qkv_attention(
235
+ self,
236
+ q: torch.Tensor,
237
+ k: torch.Tensor,
238
+ v: torch.Tensor,
239
+ mask: Optional[torch.Tensor] = None,
240
+ mask_pad: Optional[torch.Tensor] = None,
241
+ freqs_cis: Optional[torch.Tensor] = None,
242
+ ):
243
+ _, _, D = q.shape
244
+ scale = (D // self.n_head) ** -0.25
245
+ q = q.view(*q.shape[:2], self.n_head, -1)
246
+ k = k.view(*k.shape[:2], self.n_head, -1)
247
+ v = v.view(*v.shape[:2], self.n_head, -1)
248
+
249
+ if freqs_cis is not None:
250
+ q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
251
+
252
+ fsm_memory = self.forward_fsmn(v, mask_pad)
253
+
254
+ q = q.permute(0, 2, 1, 3) * scale
255
+ v = v.permute(0, 2, 1, 3)
256
+
257
+ if not self.use_sdpa:
258
+ k = k.permute(0, 2, 3, 1) * scale
259
+ qk = q @ k # (B, n_head, T, T)
260
+ if mask is not None:
261
+ qk = qk + mask
262
+ qk = qk.float()
263
+ w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
264
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
265
+ else:
266
+ k = k.permute(0, 2, 1, 3) * scale
267
+ assert mask is not None
268
+ output = torch.nn.functional.scaled_dot_product_attention(
269
+ q,
270
+ k,
271
+ v,
272
+ attn_mask=mask,
273
+ dropout_p=0.0,
274
+ scale=1.0,
275
+ )
276
+ output = output.transpose(1, 2).contiguous().view(q.size(0), -1, D) # (batch, time1, d_model)
277
+ return output, None, fsm_memory
278
+
279
+ def forward(
280
+ self,
281
+ x: torch.Tensor,
282
+ mask: Optional[torch.Tensor] = None,
283
+ mask_pad: Optional[torch.Tensor] = None,
284
+ freqs_cis: Optional[torch.Tensor] = None,
285
+ ):
286
+ q = self.query(x)
287
+ k = self.key(x)
288
+ v = self.value(x)
289
+
290
+ wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad, freqs_cis)
291
+ return self.out(wv) + fsm_memory, qk
292
+
293
+
294
+ class ResidualAttentionBlock(torch.nn.Module):
295
+ def __init__(
296
+ self,
297
+ n_state: int,
298
+ n_head: int,
299
+ kernel_size: int = 31,
300
+ use_sdpa: bool = False,
301
+ ):
302
+ super().__init__()
303
+
304
+ self.attn = FSMNMultiHeadAttention(n_state, n_head, kernel_size, use_sdpa=use_sdpa)
305
+ self.attn_ln = LayerNorm(n_state, eps=1e-6)
306
+
307
+ n_mlp = n_state * 4
308
+
309
+ self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(), Linear(n_mlp, n_state))
310
+ self.mlp_ln = LayerNorm(n_state)
311
+
312
+ def forward(
313
+ self,
314
+ x: torch.Tensor,
315
+ mask: Optional[torch.Tensor] = None,
316
+ mask_pad: Optional[torch.Tensor] = None,
317
+ freqs_cis: Optional[torch.Tensor] = None,
318
+ ):
319
+ x = x + self.attn(self.attn_ln(x), mask=mask, mask_pad=mask_pad, freqs_cis=freqs_cis)[0]
320
+
321
+ x = x + self.mlp(self.mlp_ln(x))
322
+ return x
323
+
324
+
325
+ class AudioEncoderV2(torch.nn.Module):
326
+ def __init__(
327
+ self,
328
+ n_mels: int,
329
+ n_state: int,
330
+ n_head: int,
331
+ n_layer: int,
332
+ stride: int,
333
+ use_sdpa: bool,
334
+ ):
335
+ super().__init__()
336
+ self.stride = stride
337
+
338
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=stride, padding=1)
339
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
340
+ self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
341
+ self.blocks = torch.nn.ModuleList(
342
+ [ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa) for _ in range(n_layer)]
343
+ )
344
+
345
+ def forward(self, x: torch.Tensor, x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
346
+ """
347
+ x : torch.Tensor, shape = (batch_size, n_mels, T)
348
+ the mel spectrogram of the audio
349
+ x_len: torch.Tensor, shape = (batch_size,)
350
+ length of each audio in x
351
+ """
352
+ mask = self.make_non_pad_mask(x_len).unsqueeze(1)
353
+ x = torch.nn.functional.gelu(self.conv1(x * mask))
354
+ x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
355
+ mask = self.make_non_pad_mask(x_len).unsqueeze(1)
356
+ x = torch.nn.functional.gelu(self.conv2(x * mask))
357
+ x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
358
+ mask = self.make_non_pad_mask(x_len).unsqueeze(1)
359
+ x = x.permute(0, 2, 1) # (B, T // 2, n_state)
360
+ freqs_cis = self.freqs_cis.to(x.device)
361
+ mask_pad = mask.transpose(1, 2)
362
+ mask = self.mask_to_bias(mask, x.dtype)
363
+
364
+ tmp = torch.view_as_real(freqs_cis)
365
+ cos, sin = tmp[:, :, 0], tmp[:, :, 1]
366
+
367
+ cos = torch.cat((cos, cos), dim=-1)
368
+ sin = torch.cat((sin, sin), dim=-1)
369
+ cos = cos.unsqueeze(0).unsqueeze(2)
370
+ sin = sin.unsqueeze(0).unsqueeze(2)
371
+
372
+ for block in self.blocks:
373
+ x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[: x.size(1)])
374
+
375
+ return x, x_len
376
+
377
+ @staticmethod
378
+ def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
379
+ """Make mask tensor containing indices of non-padded part.
380
+ The sequences in a batch may have different lengths. To enable
381
+ batch computing, padding is needed to make all sequences the same
382
+ size. To prevent the padded part from passing values to context-dependent
383
+ blocks such as attention or convolution, the padded part is
384
+ masked.
385
+ 1 for non-padded part and 0 for padded part.
386
+ Parameters
387
+ ----------
388
+ lengths (torch.Tensor): Batch of lengths (B,).
389
+ Returns:
390
+ -------
391
+ torch.Tensor: Mask tensor with 1 at non-padded positions and 0 at padded positions (B, max_T).
392
+ Examples:
393
+ >>> import torch
394
+ >>> import s3tokenizer
395
+ >>> lengths = torch.tensor([5, 3, 2])
396
+ >>> masks = s3tokenizer.make_non_pad_mask(lengths)
397
+ masks = [[1, 1, 1, 1, 1],
398
+ [1, 1, 1, 0, 0],
399
+ [1, 1, 0, 0, 0]]
400
+ """
401
+ batch_size = lengths.size(0)
402
+ max_len = max_len if max_len > 0 else lengths.max().item()
403
+ seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
404
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
405
+ seq_length_expand = lengths.unsqueeze(-1)
406
+ mask = seq_range_expand >= seq_length_expand
407
+ return ~mask
408
+
409
+ @staticmethod
410
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
411
+ """Convert bool-tensor to float-tensor for flash attention.
412
+ Parameters
413
+ ----------
414
+ mask (torch.Tensor): Bool mask tensor (B, ?). dtype (torch.dtype): Target floating-point dtype.
415
+ Returns:
416
+ -------
417
+ torch.Tensor: Additive attention-bias tensor (B, ?): 0 at valid positions, -1e10 at padded positions.
418
+ Examples:
419
+ >>> import torch
420
+ >>> import s3tokenizer
421
+ >>> lengths = torch.tensor([5, 3, 2])
422
+ >>> masks = self.make_non_pad_mask(lengths)
423
+ masks = [[1, 1, 1, 1, 1],
424
+ [1, 1, 1, 0, 0],
425
+ [1, 1, 0, 0, 0]]
426
+ >>> new_masks = self.mask_to_bias(masks, torch.float32)
427
+ new_masks =
428
+ [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
429
+ [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
430
+ [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
431
+ """
432
+ assert mask.dtype == torch.bool
433
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
434
+ mask = mask.to(dtype)
435
+
436
+ # attention mask bias
437
+ # NOTE(Mddct): torch.finfo jit issues
438
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
439
+ mask = (1.0 - mask) * -1.0e10
440
+ return mask
441
+
442
+
443
+ class CosyvoiceEncoder(nn.Module):
444
+ """S3 tokenizer of the CosyVoice2 implementation (inference-only).
445
+ Args:
446
+ config (ModelConfig): Config
447
+ """
448
+
449
+ def __init__(self, config: ModelConfig = ModelConfig()):
450
+ super().__init__()
451
+ self.config = config
452
+ self.encoder = AudioEncoderV2(
453
+ self.config.n_mels,
454
+ self.config.n_audio_state,
455
+ self.config.n_audio_head,
456
+ self.config.n_audio_layer,
457
+ 2,
458
+ self.config.use_sdpa,
459
+ )
460
+ self.quantizer = FSQVectorQuantization(
461
+ self.config.n_audio_state,
462
+ self.config.n_codebook_size,
463
+ )
464
+
465
+ def forward(self, wav: torch.Tensor) -> torch.Tensor:
466
+ mel = self.mel_spectrogram(wav, n_mels=self.config.n_mels)
467
+ mel_len = torch.tensor([mel.shape[-1]]).to(self.device)
468
+ return self.quantize(mel, mel_len)
469
+
470
+ @torch.inference_mode()
471
+ def quantize(self, mel: torch.Tensor, mel_len: torch.Tensor) -> torch.Tensor:
472
+ hidden, code_len = self.encoder(mel, mel_len)
473
+ code = self.quantizer.encode(hidden)
474
+ return code
475
+
476
+ @staticmethod
477
+ def mel_spectrogram(
478
+ wav: torch.Tensor,
479
+ n_mels: int = 80,
480
+ padding: int = 0,
481
+ ) -> torch.Tensor:
482
+ """
483
+ This method is based on the whisper.log_mel_spectrogram().
484
+ Do not use it as a general-purpose mel spectrogram function.
485
+ """
486
+ device = wav.device
487
+ if padding > 0:
488
+ wav = torch.nn.functional.pad(wav, (0, padding))
489
+
490
+ window = torch.hann_window(400).to(device)
491
+ stft = torch.stft(wav, 400, 160, window=window, return_complex=True)
492
+ mag = stft[..., :-1].abs() ** 2
493
+
494
+ filters = torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=n_mels)).to(device)
495
+ mel_spec = filters @ mag
496
+
497
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
498
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
499
+ log_spec = (log_spec + 4.0) / 4.0
500
+ return log_spec
501
+
502
+ @property
503
+ def device(self):
504
+ return next(self.parameters()).device
505
+
506
+ def freeze(self):
507
+ for p in self.parameters():
508
+ p.requires_grad = False
509
+
510
+ @classmethod
511
+ def from_pretrained(cls, model_path: str):
512
+ model = cls()
513
+ model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=True)
514
+ model.eval()
515
+ model.freeze()
516
+ return model
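A hedged end-to-end sketch of the tokenizer above. The randomly initialized encoder is used purely to illustrate shapes and the code range; in practice the weights are loaded by the surrounding VLM code (or via CosyvoiceEncoder.from_pretrained with a state-dict path), and audio is expected as 16 kHz mono per DEFAULT_SAMPLE_RATE:

import torch
from cosyvoice import CosyvoiceEncoder, DEFAULT_SAMPLE_RATE  # assumes this file is on the Python path

tokenizer = CosyvoiceEncoder()  # random init here; real weights come from the checkpoint
tokenizer.eval()

wav = torch.randn(1, DEFAULT_SAMPLE_RATE)  # 1 second of 16 kHz mono audio, shape (batch, samples)
with torch.inference_mode():
    codes = tokenizer(wav)                 # mel (100 fps) -> two stride-2 convs -> FSQ codes at ~25 Hz

print(codes.shape)               # roughly (1, 25): about 25 discrete codes per second of audio
print(int(codes.max()) < 3**8)   # True: codes index the 3**8 = 6561-entry FSQ codebook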
decoder/audio/NCCosybigvganDecoder.mar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b71bace1a8ed9f1eac40d98e99ebe9978a25b2de68c25d89674743d550d9abec
3
+ size 517187360
decoder/audio/NCZSCosybigvganDecoder.mar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d3cf46e024952093e19dded52befc61c751e3a759138cc055f8b008d1da34a0
3
+ size 539807544
decoder/vision/model_index.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_class_name": "VisionTokenToImagePipeline",
3
+ "_diffusers_version": "0.32.2",
4
+ "_custom_pipeline": "pipeline",
5
+ "transformer": [
6
+ "pipeline",
7
+ "VisionTransformer"
8
+ ],
9
+ "vae": [
10
+ "diffusers",
11
+ "AutoencoderKL"
12
+ ],
13
+ "scheduler": [
14
+ "diffusers",
15
+ "FlowMatchEulerDiscreteScheduler"
16
+ ],
17
+ "token_embedder": [
18
+ "pipeline",
19
+ "VisionTokenEmbedder"
20
+ ],
21
+ "transformer2": [
22
+ "pipeline",
23
+ "VisionTransformer"
24
+ ]
25
+ }
decoder/vision/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.35.2",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "invert_sigmas": false,
7
+ "max_image_seq_len": 4096,
8
+ "max_shift": 1.15,
9
+ "num_train_timesteps": 1000,
10
+ "shift": 1.0,
11
+ "shift_terminal": null,
12
+ "stochastic_sampling": false,
13
+ "time_shift_type": "exponential",
14
+ "use_beta_sigmas": false,
15
+ "use_dynamic_shifting": false,
16
+ "use_exponential_sigmas": false,
17
+ "use_karras_sigmas": false
18
+ }
decoder/vision/token_embedder/config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_class_name": "VisionTokenEmbedder",
3
+ "_diffusers_version": "0.35.2",
4
+ "embedding_dim": 1536,
5
+ "token_length": 729,
6
+ "vocab_size": 65536
7
+ }
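For orientation: this config describes a discrete vision vocabulary of 65,536 entries embedded into 1,536-dim vectors over 729 token positions (likely a 27x27 grid from the discrete vision tokenizer). A rough shape sketch with a plain nn.Embedding stand-in, not the actual VisionTokenEmbedder implementation:

import torch
import torch.nn as nn

embedder = nn.Embedding(num_embeddings=65536, embedding_dim=1536)  # vocab_size, embedding_dim
vision_tokens = torch.randint(0, 65536, (1, 729))                  # token_length = 729 = 27 * 27
context = embedder(vision_tokens)                                  # (1, 729, 1536) conditioning sequence
print(context.shape)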
decoder/vision/token_embedder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d84815b2f681f45dbc5dbbad513b04f7fe3fa4444ca6410a9354555bb3410c7f
3
+ size 201329872
decoder/vision/transformer/config.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "_class_name": "VisionTransformer",
3
+ "_diffusers_version": "0.35.2",
4
+ "axes_dim": [
5
+ 8,
6
+ 36,
7
+ 36
8
+ ],
9
+ "context_in_dim": 1536,
10
+ "depth": 0,
11
+ "depth_single_blocks": 35,
12
+ "guidance_embed": false,
13
+ "hidden_size": 1920,
14
+ "in_channels": 16,
15
+ "mlp_ratio": 4.0,
16
+ "num_heads": 24,
17
+ "qkv_bias": true,
18
+ "theta": 10000,
19
+ "use_patchify": false,
20
+ "vec_in_dim": 1536
21
+ }
decoder/vision/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f07c014f8575090455e9d3c024b4f58a8ad480b957f2c45fc6eec4fc08edbe94
3
+ size 3914661840
decoder/vision/transformer2/config.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "_class_name": "VisionTransformer",
3
+ "_diffusers_version": "0.35.2",
4
+ "axes_dim": [
5
+ 6,
6
+ 18,
7
+ 18
8
+ ],
9
+ "context_in_dim": 1536,
10
+ "depth": 0,
11
+ "depth_single_blocks": 25,
12
+ "guidance_embed": false,
13
+ "hidden_size": 1008,
14
+ "in_channels": 16,
15
+ "mlp_ratio": 4.0,
16
+ "num_heads": 24,
17
+ "qkv_bias": true,
18
+ "theta": 10000,
19
+ "use_patchify": false,
20
+ "vec_in_dim": 1536
21
+ }
decoder/vision/transformer2/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca76187d0f52336f6d385c4d56f43f0e631a4030343b27647b30baf188bcbc96
3
+ size 777545632
decoder/vision/vae/config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.35.2",
4
+ "_name_or_path": "black-forest-labs/FLUX.1-schnell",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 16,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 1024,
28
+ "scaling_factor": 0.3611,
29
+ "shift_factor": 0.1159,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": false,
37
+ "use_quant_conv": false
38
+ }
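The scaling_factor / shift_factor above are FLUX-style latent normalization constants. A minimal sketch of how diffusers pipelines typically apply them around this AutoencoderKL (standard diffusers usage, not code shipped in this repository; the path assumes the directory layout shown here):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("decoder/vision", subfolder="vae")

image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor  # normalize for the DiT
    # ... the diffusion transformer denoises in this normalized 16-channel latent space ...
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor    # undo before decoding
    pixels = vae.decode(latents).sample                                        # (1, 3, 512, 512)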
decoder/vision/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5b59a26851551b67ae1fe58d32e76486e1e812def4696a4bea97f16604d40a3
3
+ size 167666902
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 128000,
4
+ "eos_token_id": 0,
5
+ "transformers_version": "4.52.4"
6
+ }
mambamia_videoaudio_compressor.py ADDED
@@ -0,0 +1,803 @@
1
+ # -*- coding: utf-8 -*-
2
+ # This module is integrated into 'HyperCLOVAX-SEED-Omni-8B' to mitigate
3
+ # audio stream token explosion during hour-long video understanding.
4
+ # It utilizes the MambaMia architecture (AAAI-26 Oral) to
5
+ # effectively compress high-frequency audio tokens into a manageable
6
+ # context for the LLM.
7
+ # Research Context:
8
+ # - MambaMia: https://github.com/naver-ai/mambamia
9
+ # - LLaVA-AV-SSM: https://github.com/naver-ai/LLaVA-AV-SSM
10
+ # Acknowledgements:
11
+ # This implementation is heavily modified and extended from the following
12
+ # foundational repositories:
13
+ # - Transformers: https://github.com/huggingface/transformers (Apache License v2.0)
14
+ # - Mamba: https://github.com/state-spaces/mamba (Apache License v2.0)
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+ from transformers.activations import ACT2FN
24
+ from transformers.modeling_utils import PreTrainedModel
25
+ from transformers.configuration_utils import PretrainedConfig
26
+ from transformers.utils import ModelOutput, logging
27
+ from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ # ============================================================================
34
+ # Check for fast path availability
35
+ # ============================================================================
36
+ if is_mamba_2_ssm_available():
37
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
39
+ else:
40
+ selective_state_update = None
41
+ mamba_split_conv1d_scan_combined = None
42
+
43
+ if is_causal_conv1d_available():
44
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
45
+ else:
46
+ causal_conv1d_update, causal_conv1d_fn = None, None
47
+
48
+ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
49
+
50
+
51
+ # ============================================================================
52
+ # MambaMia2Config (Simplified for v04 only)
53
+ # ============================================================================
54
+ class MambaMia2Config(PretrainedConfig):
55
+ """
56
+ Simplified MambaMia2 configuration for v04 version only.
57
+ """
58
+ model_type = "mamba2"
59
+
60
+ def __init__(
61
+ self,
62
+ num_heads=128,
63
+ head_dim=64,
64
+ vocab_size=32768,
65
+ hidden_size=4096,
66
+ state_size=128,
67
+ num_hidden_layers=64,
68
+ layer_norm_epsilon=1e-5,
69
+ pad_token_id=1,
70
+ bos_token_id=0,
71
+ eos_token_id=2,
72
+ expand=2,
73
+ conv_kernel=4,
74
+ n_groups=8,
75
+ use_bias=False,
76
+ use_conv_bias=True,
77
+ hidden_act="silu",
78
+ initializer_range=0.1,
79
+ residual_in_fp32=False,
80
+ time_step_rank="auto",
81
+ time_step_min=0.001,
82
+ time_step_max=0.1,
83
+ time_step_floor=1e-4,
84
+ time_step_limit=(0.0, float("inf")),
85
+ rescale_prenorm_residual=False,
86
+ use_cache=True,
87
+ norm_before_gate=True,
88
+ rms_norm=True,
89
+ chunk_size=256,
90
+ tie_word_embeddings=False,
91
+ mambamia_chunk_size=10,
92
+ **kwargs,
93
+ ):
94
+ self.vocab_size = vocab_size
95
+ self.hidden_size = hidden_size
96
+ self.state_size = state_size
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.layer_norm_epsilon = layer_norm_epsilon
99
+ self.conv_kernel = conv_kernel
100
+ self.expand = expand
101
+
102
+ self.bos_token_id = bos_token_id
103
+ self.eos_token_id = eos_token_id
104
+ self.pad_token_id = pad_token_id
105
+ self.use_bias = use_bias
106
+ self.use_conv_bias = use_conv_bias
107
+ self.hidden_act = hidden_act
108
+ self.initializer_range = initializer_range
109
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
110
+ self.time_step_min = time_step_min
111
+ self.time_step_max = time_step_max
112
+ self.time_step_floor = time_step_floor
113
+ self.rescale_prenorm_residual = rescale_prenorm_residual
114
+ self.residual_in_fp32 = residual_in_fp32
115
+ self.use_cache = use_cache
116
+ self.n_groups = n_groups
117
+ self.num_heads = num_heads
118
+ self.head_dim = head_dim
119
+ self.norm_before_gate = norm_before_gate
120
+ self.rms_norm = rms_norm
121
+ self.state_size = state_size
122
+ self.chunk_size = chunk_size
123
+ self.time_step_limit = time_step_limit
124
+ self.tie_word_embeddings = tie_word_embeddings
125
+ self.mambamia_chunk_size = mambamia_chunk_size
126
+ self.output_hidden_states = False
127
+ self.output_deltas = False
128
+
129
+ super().__init__(
130
+ bos_token_id=bos_token_id,
131
+ eos_token_id=eos_token_id,
132
+ pad_token_id=pad_token_id,
133
+ tie_word_embeddings=tie_word_embeddings,
134
+ **kwargs,
135
+ )
136
+
137
+
138
+ # ============================================================================
139
+ # Helper Modules
140
+ # ============================================================================
141
+ class MambaRMSNormGated(nn.Module):
142
+ def __init__(self, hidden_size, eps=1e-6):
143
+ super().__init__()
144
+ self.weight = nn.Parameter(torch.ones(hidden_size))
145
+ self.variance_epsilon = eps
146
+
147
+ def forward(self, hidden_states, gate=None):
148
+ input_dtype = hidden_states.dtype
149
+ hidden_states = hidden_states.to(torch.float32)
150
+ if gate is not None:
151
+ hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
152
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
153
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
154
+ return self.weight * hidden_states.to(input_dtype)
155
+
156
+
157
+ class MambaMia2RMSNorm(nn.Module):
158
+ def __init__(self, hidden_size, eps=1e-6):
159
+ super().__init__()
160
+ self.weight = nn.Parameter(torch.ones(hidden_size))
161
+ self.variance_epsilon = eps
162
+
163
+ def forward(self, hidden_states):
164
+ input_dtype = hidden_states.dtype
165
+ hidden_states = hidden_states.to(torch.float32)
166
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
167
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
168
+ return self.weight * hidden_states.to(input_dtype)
169
+
170
+
171
+ # ============================================================================
172
+ # MambaMia2Mixer (v04 version - unidirectional with GPA)
173
+ # ============================================================================
174
+ class MambaMia2Mixer(nn.Module):
175
+ """
176
+ Unidirectional Mamba2 Mixer for v04 version.
177
+ v04 = v0 (unidirectional Mamba) + GPA (Gated Pooling Attention in Block)
178
+ """
179
+
180
+ def __init__(self, config: MambaMia2Config, layer_idx: int):
181
+ super().__init__()
182
+ self.num_heads = config.num_heads
183
+ self.hidden_size = config.hidden_size
184
+ self.ssm_state_size = config.state_size
185
+ self.conv_kernel_size = config.conv_kernel
186
+ self.intermediate_size = int(config.expand * self.hidden_size)
187
+ self.time_step_rank = int(config.time_step_rank)
188
+ self.layer_idx = layer_idx
189
+ self.use_conv_bias = config.use_conv_bias
190
+ self.activation = config.hidden_act
191
+ self.act = ACT2FN[config.hidden_act]
192
+
193
+ self.norm_before_gate = config.norm_before_gate
194
+ self.layer_norm_epsilon = config.layer_norm_epsilon
195
+ self.rms_norm = config.rms_norm
196
+
197
+ self.n_groups = config.n_groups
198
+ self.head_dim = config.head_dim
199
+ self.chunk_size = config.chunk_size
200
+
201
+ self.time_step_limit = config.time_step_limit
202
+ self.time_step_min = config.time_step_min
203
+ self.time_step_max = config.time_step_max
204
+
205
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
206
+
207
+ # Conv1d for SSM
208
+ self.conv1d = nn.Conv1d(
209
+ in_channels=self.conv_dim,
210
+ out_channels=self.conv_dim,
211
+ bias=config.use_conv_bias,
212
+ kernel_size=config.conv_kernel,
213
+ groups=self.conv_dim,
214
+ padding=config.conv_kernel - 1,
215
+ )
216
+
217
+ # projection of the input hidden states
218
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
219
+ self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=config.use_bias)
220
+
221
+ # time step projection
222
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
223
+
224
+ # S4D real initialization
225
+ A = torch.arange(1, self.num_heads + 1)
226
+ self.A_log = nn.Parameter(torch.log(A))
227
+ self.A_log._no_weight_decay = True
228
+ self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
229
+ self.D = nn.Parameter(torch.ones(self.num_heads))
230
+ self.D._no_weight_decay = True
231
+
232
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
233
+ self.use_bias = config.use_bias
234
+
235
+ if not is_fast_path_available:
236
+ logger.warning_once(
237
+ "The fast path is not available because one of "
238
+ "`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. "
239
+ "Falling back to the naive implementation. To install follow "
240
+ "https://github.com/state-spaces/mamba/#installation and "
241
+ "https://github.com/Dao-AILab/causal-conv1d"
242
+ )
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.Tensor,
247
+ attention_mask: Optional[torch.Tensor] = None,
248
+ ):
249
+ """
250
+ v04 unidirectional forward pass using CUDA kernels.
251
+ """
252
+ import os
253
+ rank = int(os.environ.get("RANK", -1))
254
+ debug = False # (rank <= 0)
255
+
256
+ assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, \
257
+ "CUDA kernels required for MambaMia2Mixer"
258
+
259
+ dtype = hidden_states.dtype
260
+ batch_size, seq_len, _ = hidden_states.shape
261
+
262
+ if debug:
263
+ print(f"[Mixer DEBUG] input: min={hidden_states.min().item():.6f}, max={hidden_states.max().item():.6f}, nan={torch.isnan(hidden_states).any().item()}, seq_len={seq_len}, chunk_size={self.chunk_size}")
264
+
265
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
266
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
267
+
268
+ # Gated MLP's linear projection
269
+ projected_states = self.in_proj(hidden_states)
270
+
271
+ if debug:
272
+ print(f"[Mixer DEBUG] after in_proj: min={projected_states.min().item():.6f}, max={projected_states.max().item():.6f}, nan={torch.isnan(projected_states).any().item()}")
273
+ print(f"[Mixer DEBUG] A_log: {self.A_log[:5].tolist()}, dt_bias: {self.dt_bias[:5].tolist()}, D: {self.D[:5].tolist()}")
274
+
275
+ dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
276
+
277
+ # Unidirectional forward pass (same as v0)
278
+ outputs = mamba_split_conv1d_scan_combined(
279
+ projected_states,
280
+ self.conv1d.weight.squeeze(1),
281
+ self.conv1d.bias,
282
+ self.dt_bias,
283
+ -torch.exp(self.A_log.float()),
284
+ D=self.D,
285
+ chunk_size=self.chunk_size,
286
+ seq_idx=None,
287
+ activation=self.activation,
288
+ rmsnorm_weight=self.norm.weight,
289
+ rmsnorm_eps=self.norm.variance_epsilon,
290
+ outproj_weight=self.out_proj.weight,
291
+ outproj_bias=self.out_proj.bias,
292
+ headdim=self.head_dim,
293
+ ngroups=self.n_groups,
294
+ norm_before_gate=self.norm_before_gate,
295
+ return_final_states=False,
296
+ **dt_limit_kwargs,
297
+ )
298
+
299
+ if debug:
300
+ print(f"[Mixer DEBUG] after mamba_kernel: min={outputs.min().item():.6f}, max={outputs.max().item():.6f}, nan={torch.isnan(outputs).any().item()}")
301
+
302
+ return outputs.to(dtype)
303
+
304
+
305
+ # ============================================================================
306
+ # MambaMia2Block (v04 version only)
307
+ # ============================================================================
308
+ class MambaMia2Block(nn.Module):
309
+ """
310
+ Single MambaMia2 block with v04 gated pooling attention mechanism.
311
+ """
312
+
313
+ def __init__(self, config: MambaMia2Config, layer_idx: int):
314
+ super().__init__()
315
+ self.config = config
316
+ self.layer_idx = layer_idx
317
+ self.residual_in_fp32 = config.residual_in_fp32
318
+ self.mambamia_chunk_size = config.mambamia_chunk_size
319
+
320
+ self.norm = MambaMia2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
321
+ self.mixer = MambaMia2Mixer(config, layer_idx=layer_idx)
322
+
323
+ # v04 specific: Gated Pooling Attention (GPA)
324
+ self.drop = nn.Dropout(p=0.1)
325
+
326
+ # Per-frame weight prediction
327
+ self.weight_fc = nn.Linear(config.hidden_size, self.mambamia_chunk_size)
328
+ nn.init.zeros_(self.weight_fc.bias)
329
+ with torch.no_grad():
330
+ self.weight_fc.weight.mul_(1e-3)
331
+
332
+ # Query vs aggregator gating
333
+ self.gate_fc = nn.Linear(config.hidden_size, 1)
334
+ nn.init.zeros_(self.gate_fc.bias)
335
+ with torch.no_grad():
336
+ self.gate_fc.weight.mul_(1e-3)
337
+
338
+ def forward(
339
+ self,
340
+ hidden_states: torch.Tensor,
341
+ attention_mask: Optional[torch.Tensor] = None,
342
+ ):
343
+ input_dtype = hidden_states.dtype
344
+ residual = hidden_states
345
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
346
+ if self.residual_in_fp32:
347
+ residual = residual.to(torch.float32)
348
+
349
+ # v04 Gated Pooling Attention
350
+ assert hidden_states.dim() == 3, f"hidden_states.dim()={hidden_states.dim()} != 3"
351
+ bsz, seq_len, hidden_dim = hidden_states.shape
352
+ mambamia_chunk_size = self.mambamia_chunk_size
353
+ chunk_with_query = mambamia_chunk_size + 1
354
+
355
+ if seq_len % chunk_with_query != 0:
356
+ raise ValueError(
357
+ f"seq_len={seq_len} must be divisible by (mambamia_chunk_size+1)={chunk_with_query}"
358
+ )
359
+ n_chunk = seq_len // chunk_with_query
360
+
361
+ # Reshape to (bsz, n_chunk, chunk_size+1, hidden_dim)
362
+ hidden_4d = hidden_states.view(bsz, n_chunk, chunk_with_query, hidden_dim)
363
+
364
+ frames = hidden_4d[:, :, :mambamia_chunk_size, :] # (bsz, n_chunk, chunk_size, hidden_dim)
365
+ queries = hidden_4d[:, :, mambamia_chunk_size, :] # (bsz, n_chunk, hidden_dim)
366
+
367
+ # Weight prediction for frames (computed in float32 for numerical stability)
368
+ w_in = self.drop(queries)
369
+ raw_weights = self.weight_fc(w_in)
370
+ alpha = torch.softmax(raw_weights.float(), dim=-1).to(input_dtype) # (bsz, n_chunk, chunk_size)
371
+
372
+ # Weighted average: aggregator
373
+ aggregator = (frames * alpha.unsqueeze(-1)).sum(dim=2) # (bsz, n_chunk, hidden_dim)
374
+
375
+ # Gating between queries and aggregator (computed in float32)
376
+ gating_in = self.drop(queries)
377
+ gating = torch.sigmoid(self.gate_fc(gating_in).float()).to(input_dtype) # (bsz, n_chunk, 1)
378
+ epsilon = 0.01
379
+ gating = gating * (1 - 2 * epsilon) + epsilon # [0.01, 0.99]
380
+
381
+ gating_broad = gating.expand(-1, -1, hidden_dim)
382
+ aggregator = aggregator * gating_broad
383
+ queries = queries * (1 - gating_broad)
384
+ queries_new = queries + aggregator
385
+
386
+ # Update query positions
387
+ hidden_4d = hidden_4d.clone()
388
+ hidden_4d[:, :, mambamia_chunk_size, :] = queries_new
389
+ hidden_states = hidden_4d.view(bsz, seq_len, hidden_dim)
390
+
391
+ # Mixer forward
392
+ hidden_states = self.mixer(hidden_states, attention_mask=attention_mask)
393
+
394
+ # Residual connection
395
+ hidden_states = hidden_states + residual
396
+
397
+ return hidden_states
398
+
399
+
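Editorial note (not part of the commit): the gated pooling above splits the sequence into chunks of `mambamia_chunk_size` frames followed by one query position, pools the frames into the query with softmax weights, and gates the result against the original query. A minimal sketch of just that arithmetic, with toy sizes and random stand-ins for the `weight_fc`/`gate_fc` predictions:

import torch

bsz, n_chunk, chunk, dim = 2, 3, 25, 8                              # hypothetical sizes
hidden = torch.randn(bsz, n_chunk * (chunk + 1), dim)               # [25 frames, 1 query] per chunk
hidden_4d = hidden.view(bsz, n_chunk, chunk + 1, dim)
frames, queries = hidden_4d[:, :, :chunk, :], hidden_4d[:, :, chunk, :]

alpha = torch.softmax(torch.randn(bsz, n_chunk, chunk), dim=-1)     # stand-in for the weight_fc output
aggregator = (frames * alpha.unsqueeze(-1)).sum(dim=2)              # weighted frame average
gate = torch.sigmoid(torch.randn(bsz, n_chunk, 1)) * 0.98 + 0.01    # stand-in for gate_fc, clamped to [0.01, 0.99]
queries_new = queries * (1 - gate) + aggregator * gate              # gated mix written back to the query slot
print(queries_new.shape)                                            # torch.Size([2, 3, 8])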
400
+ # ============================================================================
401
+ # MambaMia2Model (Simplified)
402
+ # ============================================================================
403
+ @dataclass
404
+ class MambaMia2Output(ModelOutput):
405
+ """Output class for MambaMia2Model."""
406
+ last_hidden_state: Optional[torch.FloatTensor] = None
407
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
408
+
409
+
410
+ class MambaMia2PreTrainedModel(PreTrainedModel):
411
+ """Base class for MambaMia2 models."""
412
+ config_class = MambaMia2Config
413
+ base_model_prefix = "backbone"
414
+ _no_split_modules = ["MambaMia2Block"]
415
+ supports_gradient_checkpointing = True
416
+
417
+ def _init_weights(self, module):
418
+ if isinstance(module, MambaMia2Mixer):
419
+ module.A_log._no_weight_decay = True
420
+ module.D._no_weight_decay = True
421
+
422
+ dt = torch.exp(
423
+ torch.rand(self.config.num_heads)
424
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
425
+ + math.log(self.config.time_step_min)
426
+ ).clamp(min=self.config.time_step_floor)
427
+
428
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
429
+ with torch.no_grad():
430
+ module.dt_bias.copy_(inv_dt)
431
+ module.dt_bias._no_reinit = True
432
+
433
+ if isinstance(module, nn.Linear):
434
+ if module.bias is not None:
435
+ if not getattr(module.bias, "_no_reinit", False):
436
+ nn.init.zeros_(module.bias)
437
+ elif isinstance(module, nn.Embedding):
438
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
439
+
440
+ if self.config.rescale_prenorm_residual:
441
+ for name, p in module.named_parameters():
442
+ if name in ["out_proj.weight"]:
443
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
444
+ with torch.no_grad():
445
+ p /= math.sqrt(self.config.num_hidden_layers)
446
+
447
+
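Editorial note (not part of the commit): `inv_dt` above is the softplus inverse of the sampled time step, so `softplus(dt_bias)` inside the SSM kernel recovers `dt`. A quick sanity check:

import torch
import torch.nn.functional as F

dt = torch.tensor([0.001, 0.01, 0.1])          # sampled time steps
inv_dt = dt + torch.log(-torch.expm1(-dt))     # equals log(exp(dt) - 1), i.e. softplus^-1(dt)
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)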
448
+ class MambaMia2Model(MambaMia2PreTrainedModel):
449
+ """
450
+ Simplified MambaMia2 Model for v04 version.
451
+ Takes inputs_embeds directly (no embedding layer used for audio/video).
452
+ """
453
+
454
+ def __init__(self, config: MambaMia2Config):
455
+ super().__init__(config)
456
+ self.layers = nn.ModuleList([
457
+ MambaMia2Block(config, layer_idx=idx)
458
+ for idx in range(config.num_hidden_layers)
459
+ ])
460
+ self.gradient_checkpointing = False
461
+ self.norm_f = MambaMia2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
462
+ self.post_init()
463
+
464
+ def forward(
465
+ self,
466
+ inputs_embeds: torch.Tensor,
467
+ attention_mask: Optional[torch.Tensor] = None,
468
+ output_hidden_states: Optional[bool] = None,
469
+ return_dict: Optional[bool] = None,
470
+ ) -> Union[Tuple, MambaMia2Output]:
471
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
472
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
473
+
474
+ hidden_states = inputs_embeds
475
+ all_hidden_states = () if output_hidden_states else None
476
+
477
+ for mixer_block in self.layers:
478
+ if self.gradient_checkpointing and self.training:
479
+ hidden_states = self._gradient_checkpointing_func(
480
+ mixer_block.__call__, hidden_states, attention_mask
481
+ )
482
+ else:
483
+ hidden_states = mixer_block(hidden_states, attention_mask=attention_mask)
484
+
485
+ if output_hidden_states:
486
+ all_hidden_states = all_hidden_states + (hidden_states,)
487
+
488
+ hidden_states = self.norm_f(hidden_states)
489
+
490
+ if output_hidden_states:
491
+ all_hidden_states = all_hidden_states + (hidden_states,)
492
+
493
+ if not return_dict:
494
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
495
+
496
+ return MambaMia2Output(
497
+ last_hidden_state=hidden_states,
498
+ hidden_states=all_hidden_states,
499
+ )
500
+
501
+
502
+ # ============================================================================
503
+ # MambaMiaVideoAudioCompressorConfig
504
+ # ============================================================================
505
+ class MambaMiaVideoAudioCompressorConfig(PretrainedConfig):
506
+ """
507
+ Configuration for MambaMiaVideoAudioCompressor.
508
+
509
+ Args:
510
+ input_size: Input embedding dimension (e.g., 1280 for Whisper)
511
+ output_size: Output embedding dimension (e.g., 2048 for LLM)
512
+ chunk_size: Number of tokens per chunk (default: 25, i.e., 1 second at 25Hz)
513
+ num_hidden_layers: Number of MambaMia2 layers (default: 1)
514
+ hidden_size: Internal hidden size (default: 3072, must be divisible by 24)
515
+ """
516
+ model_type = "mambamia_videoaudio_compressor"
517
+
518
+ def __init__(
519
+ self,
520
+ input_size: int = 1280,
521
+ output_size: int = 2048,
522
+ chunk_size: int = 25,
523
+ num_hidden_layers: int = 1,
524
+ hidden_size: int = 3072,
525
+ **kwargs,
526
+ ):
527
+ super().__init__(**kwargs)
528
+ self.input_size = input_size
529
+ self.output_size = output_size
530
+ self.chunk_size = chunk_size
531
+ self.num_hidden_layers = num_hidden_layers
532
+ self.hidden_size = hidden_size
533
+
534
+
535
+ # ============================================================================
536
+ # MambaMiaVideoAudioCompressor - Main Interface (based on PreTrainedModel)
537
+ # ============================================================================
538
+ class MambaMiaVideoAudioCompressor(PreTrainedModel):
539
+ """
540
+ Video/Audio Compressor using MambaMia2 (v04 bidirectional version).
541
+
542
+ This module compresses sequential embeddings (e.g., audio frames at 25Hz)
543
+ by inserting learnable query tokens and extracting them after processing.
544
+
545
+ Args:
546
+ config: MambaMiaVideoAudioCompressorConfig
547
+
548
+ Input:
549
+ inputs_embeds: (batch_size, num_frames, hidden_dim) where num_frames is
550
+ typically the audio length and hidden_dim matches input_size
551
+
552
+ Output:
553
+ compressed_embeds: (batch_size, num_queries, output_size) where
554
+ num_queries = num_frames // chunk_size
555
+ """
556
+
557
+ config_class = MambaMiaVideoAudioCompressorConfig
558
+ base_model_prefix = "mambamia_compressor"
559
+ _no_split_modules = ["MambaMia2Block"]
560
+
561
+ def __init__(self, config: MambaMiaVideoAudioCompressorConfig):
562
+ super().__init__(config)
563
+
564
+ self.input_size = config.input_size
565
+ self.output_size = config.output_size
566
+ self.chunk_size = config.chunk_size
567
+ self.hidden_size = config.hidden_size
568
+
569
+ # Input projection: input_size -> hidden_size
570
+ self.input_proj = nn.Linear(config.input_size, config.hidden_size)
571
+
572
+ # Learnable query token
573
+ self.query_token = nn.Parameter(torch.randn(config.hidden_size))
574
+
575
+ # MambaMia2 backbone
576
+ # Important: chunk_size here is the SSM kernel chunk size and must be smaller than the sequence length
577
+ # mambamia_chunk_size is the compression ratio (25:1)
578
+ # Sequences can be fairly short (e.g., ~390 tokens), so the kernel chunk size is chosen with that in mind
579
+ mamba_config = MambaMia2Config(
580
+ vocab_size=0,
581
+ hidden_size=config.hidden_size,
582
+ num_hidden_layers=config.num_hidden_layers,
583
+ head_dim=64,
584
+ num_heads=config.hidden_size * 2 // 64, # e.g., 3072*2/64 = 96
585
+ n_groups=1,
586
+ expand=2.0,
587
+ use_cache=False,
588
+ chunk_size=256, # SSM kernel chunk size
589
+ mambamia_chunk_size=config.chunk_size, # compression ratio (25)
590
+ residual_in_fp32=False,
591
+ )
592
+ self.model = MambaMia2Model(mamba_config)
593
+
594
+ # LayerNorm before Mamba2 to normalize input scales
595
+ # This ensures query_token and input_proj outputs are on the same scale
596
+ self.input_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
597
+
598
+ # Output projection: hidden_size -> output_size
599
+ self.output_proj = nn.Linear(config.hidden_size, config.output_size)
600
+
601
+ # Initialize weights (transformers style)
602
+ self.post_init()
603
+
604
+ def _init_weights(self, module):
605
+ """
606
+ Initialize weights - called by post_init() for all submodules.
607
+ Note: weights inside MambaMia2Model are left untouched (handled by its own post_init).
608
+ """
609
+ # Initialize query_token with std=1.0 to match the scale of the input_proj outputs
610
+ # (a small std pushes the LayerNorm variance toward zero, which produces inf)
611
+ if module is self:
612
+ with torch.no_grad():
613
+ self.query_token.data.normal_(mean=0.0, std=1.0)
614
+
615
+ # Xavier-initialize only input_proj and output_proj (leave MambaMia2 internals untouched)
616
+ if module is self.input_proj or module is self.output_proj:
617
+ nn.init.xavier_uniform_(module.weight)
618
+ if module.bias is not None:
619
+ nn.init.zeros_(module.bias)
620
+
621
+ def _init_all_weights(self):
622
+ """
623
+ Force re-initialize all weights. Call after dtype conversion for FSDP compatibility.
624
+ This ensures weights are properly initialized even after model transformations.
625
+ """
626
+ # 1. Initialize input_proj and output_proj
627
+ nn.init.xavier_uniform_(self.input_proj.weight)
628
+ if self.input_proj.bias is not None:
629
+ nn.init.zeros_(self.input_proj.bias)
630
+ nn.init.xavier_uniform_(self.output_proj.weight)
631
+ if self.output_proj.bias is not None:
632
+ nn.init.zeros_(self.output_proj.bias)
633
+
634
+ # 2. Initialize query_token with std=1.0 to match the input_proj output scale
635
+ self.query_token.data.normal_(mean=0.0, std=1.0)
636
+
637
+ # 3. Initialize input_norm (LayerNorm)
638
+ nn.init.ones_(self.input_norm.weight)
639
+ nn.init.zeros_(self.input_norm.bias)
640
+
641
+ # 4. Initialize MambaMia2Model internals (important!)
642
+ for name, module in self.model.named_modules():
643
+ if isinstance(module, nn.Linear):
644
+ nn.init.xavier_uniform_(module.weight)
645
+ if module.bias is not None:
646
+ nn.init.zeros_(module.bias)
647
+ elif isinstance(module, nn.Conv1d):
648
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
649
+ if module.bias is not None:
650
+ nn.init.zeros_(module.bias)
651
+
652
+ # 5. Special initialization for the MambaMia2Block heads (weight_fc, gate_fc)
653
+ for layer in self.model.layers:
654
+ if hasattr(layer, 'weight_fc'):
655
+ nn.init.xavier_uniform_(layer.weight_fc.weight)
656
+ layer.weight_fc.weight.data.mul_(0.01) # Scale down
657
+ nn.init.zeros_(layer.weight_fc.bias)
658
+ if hasattr(layer, 'gate_fc'):
659
+ nn.init.xavier_uniform_(layer.gate_fc.weight)
660
+ layer.gate_fc.weight.data.mul_(0.01) # Scale down
661
+ nn.init.zeros_(layer.gate_fc.bias)
662
+
663
+ # 6. Initialize the A_log, D, and dt_bias parameters (SSM-specific)
664
+ for layer in self.model.layers:
665
+ if hasattr(layer, 'mixer'):
666
+ mixer = layer.mixer
667
+ # A_log: S4D real initialization
668
+ A = torch.arange(1, mixer.num_heads + 1, dtype=mixer.A_log.dtype, device=mixer.A_log.device)
669
+ mixer.A_log.data.copy_(torch.log(A))
670
+ # D: scaling factor
671
+ mixer.D.data.fill_(1.0)
672
+ # dt_bias: time step bias (important!)
673
+ mixer.dt_bias.data.fill_(1.0)
674
+
675
+ # 7. Initialize RMSNorm weights (MambaRMSNormGated)
676
+ for layer in self.model.layers:
677
+ if hasattr(layer, 'mixer') and hasattr(layer.mixer, 'norm'):
678
+ layer.mixer.norm.weight.data.fill_(1.0)
679
+ if hasattr(layer, 'norm') and hasattr(layer.norm, 'weight'):
680
+ layer.norm.weight.data.fill_(1.0)
681
+
682
+ # 8. Initialize the final norm_f of MambaMia2Model
683
+ if hasattr(self.model, 'norm_f') and hasattr(self.model.norm_f, 'weight'):
684
+ self.model.norm_f.weight.data.fill_(1.0)
685
+
686
+ def forward(
687
+ self,
688
+ inputs_embeds: torch.Tensor,
689
+ ) -> torch.Tensor:
690
+ """
691
+ Forward pass.
692
+
693
+ Args:
694
+ inputs_embeds: (batch_size, seq_len, input_size) or
695
+ (batch_size, num_frames, chunk_size, input_size)
696
+
697
+ Returns:
698
+ compressed: (batch_size, num_queries, output_size)
699
+ """
700
+ import os
701
+ rank = int(os.environ.get("RANK", -1))
702
+ debug = False # True if (rank <= 0) else False
703
+
704
+ # Handle different input shapes
705
+ if inputs_embeds.dim() == 4:
706
+ # (batch_size, num_frames, chunk_size, input_size)
707
+ bsz, num_frames, chunk_size, _ = inputs_embeds.shape
708
+ assert chunk_size == self.chunk_size, \
709
+ f"Input chunk_size {chunk_size} != expected {self.chunk_size}"
710
+ inputs_embeds = inputs_embeds.view(bsz, -1, self.input_size)
711
+
712
+ bsz, seq_len, _ = inputs_embeds.shape
713
+
714
+ # Ensure seq_len is divisible by chunk_size
715
+ if seq_len % self.chunk_size != 0:
716
+ # Pad to make divisible
717
+ pad_len = self.chunk_size - (seq_len % self.chunk_size)
718
+ inputs_embeds = F.pad(inputs_embeds, (0, 0, 0, pad_len))
719
+ seq_len = inputs_embeds.shape[1]
720
+
721
+ n_chunk = seq_len // self.chunk_size
722
+
723
+ # Project input
724
+ hidden_states = self.input_proj(inputs_embeds) # (bsz, seq_len, hidden_size)
725
+
726
+ if debug:
727
+ print(f"[MambaMia DEBUG] input_proj output: min={hidden_states.min().item():.6f}, max={hidden_states.max().item():.6f}, has_nan={torch.isnan(hidden_states).any().item()}")
728
+
729
+ # Reshape to (bsz, n_chunk, chunk_size, hidden_size)
730
+ hidden_4d = hidden_states.view(bsz, n_chunk, self.chunk_size, self.hidden_size)
731
+
732
+ # Add query token to each chunk
733
+ # query_token: (hidden_size,) -> (1, 1, 1, hidden_size)
734
+ query_expanded = self.query_token.view(1, 1, 1, -1).expand(bsz, n_chunk, 1, self.hidden_size)
735
+
736
+ if debug:
737
+ print(f"[MambaMia DEBUG] query_token: min={self.query_token.min().item():.6f}, max={self.query_token.max().item():.6f}, has_nan={torch.isnan(self.query_token).any().item()}")
738
+
739
+ # Concatenate: (bsz, n_chunk, chunk_size+1, hidden_size)
740
+ hidden_with_query = torch.cat([hidden_4d, query_expanded], dim=2)
741
+
742
+ # Flatten for model: (bsz, n_chunk * (chunk_size+1), hidden_size)
743
+ model_input = hidden_with_query.view(bsz, -1, self.hidden_size)
744
+
745
+ # Apply LayerNorm to normalize input scales before Mamba2
746
+ model_input = self.input_norm(model_input)
747
+
748
+ if debug:
749
+ print(f"[MambaMia DEBUG] model_input (after LayerNorm, before Mamba2): min={model_input.min().item():.6f}, max={model_input.max().item():.6f}, has_nan={torch.isnan(model_input).any().item()}")
750
+
751
+ # Forward through MambaMia2
752
+ outputs = self.model(inputs_embeds=model_input)
753
+ hidden_states = outputs.last_hidden_state # (bsz, n_chunk * (chunk_size+1), hidden_size)
754
+
755
+ if debug:
756
+ print(f"[MambaMia DEBUG] model output (after Mamba2): min={hidden_states.min().item():.6f}, max={hidden_states.max().item():.6f}, has_nan={torch.isnan(hidden_states).any().item()}")
757
+
758
+ # Check for NaN and replace with zeros if found (defensive)
759
+ if torch.isnan(hidden_states).any():
760
+ hidden_states = torch.nan_to_num(hidden_states, nan=0.0)
761
+
762
+ # Reshape back: (bsz, n_chunk, chunk_size+1, hidden_size)
763
+ hidden_out_4d = hidden_states.view(bsz, n_chunk, self.chunk_size + 1, self.hidden_size)
764
+
765
+ # Extract query positions (last position in each chunk)
766
+ query_outputs = hidden_out_4d[:, :, self.chunk_size, :] # (bsz, n_chunk, hidden_size)
767
+
768
+ if debug:
769
+ print(f"[MambaMia DEBUG] query_outputs (extracted): min={query_outputs.min().item():.6f}, max={query_outputs.max().item():.6f}, has_nan={torch.isnan(query_outputs).any().item()}")
770
+
771
+ # Project to output size
772
+ compressed = self.output_proj(query_outputs) # (bsz, n_chunk, output_size)
773
+
774
+ if debug:
775
+ print(f"[MambaMia DEBUG] output_proj output: min={compressed.min().item():.6f}, max={compressed.max().item():.6f}, has_nan={torch.isnan(compressed).any().item()}")
776
+
777
+ return compressed
778
+
779
+
780
+ # ============================================================================
781
+ # Convenience function for quick instantiation
782
+ # ============================================================================
783
+ def create_mambamia_compressor(
784
+ input_size: int,
785
+ output_size: int,
786
+ chunk_size: int = 25,
787
+ num_hidden_layers: int = 2,
788
+ hidden_size: int = 3072,
789
+ ) -> MambaMiaVideoAudioCompressor:
790
+ """
791
+ Create a MambaMiaVideoAudioCompressor with default settings.
792
+
793
+ Example:
794
+ compressor = create_mambamia_compressor(1280, 2048, chunk_size=25)
795
+ """
796
+ config = MambaMiaVideoAudioCompressorConfig(
797
+ input_size=input_size,
798
+ output_size=output_size,
799
+ chunk_size=chunk_size,
800
+ num_hidden_layers=num_hidden_layers,
801
+ hidden_size=hidden_size,
802
+ )
803
+ return MambaMiaVideoAudioCompressor(config)
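Editorial note (not part of the commit): a minimal usage sketch of the compressor, assuming 1280-d Whisper-style features at 25 Hz, a 2048-d LLM embedding space, and an environment where the Mamba2 kernels used by the mixer are available:

import torch

compressor = create_mambamia_compressor(input_size=1280, output_size=2048, chunk_size=25)
audio_feats = torch.randn(2, 500, 1280)   # 2 clips, 500 frames (20 s at 25 Hz)
compressed = compressor(audio_feats)      # one query token per 25-frame chunk
print(compressed.shape)                   # expected: torch.Size([2, 20, 2048])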
model-00001-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f0c21a9149c81295ffd42490cfb2daf7c9e6dcc39f11a4e4c4b4fe4be8a9e2a
3
+ size 4707522584
model-00002-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91399f86327e37a1b75ca14154cdc42d9994f39c08c34e8156503e89f10cc800
3
+ size 3454903840
model-00003-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f779b660d0a569700efef179dda58a3496decb3049b8c8269860f83b32bb647f
3
+ size 4999679056
model-00004-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73022915ccce7a97bb31c0cc60c7b7b52fe7cd4a6859949a1e8f3d60c51e2fb8
3
+ size 4832042296
model-00005-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36449388c5903053f1e6cbac60140feb7b321a006e8493ed0f10480bca540d46
3
+ size 4832042328
model-00006-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3696ca372b23a411837f961c00f05a0eaf9a3035438f37525793753bd08823fa
3
+ size 4999848088
model-00007-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5481aa36ff6148bb35812b4de5bf2af12daf4e6d07a3106fd72d082088816cab
3
+ size 4832042352
model-00008-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:772db476e45a055871d22c837a70fbd2d4960dd075031f3d2b6b5a6ea0888db2
3
+ size 4832042352
model-00009-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb63af1e386f7ccf472c4d3108c134d863869ab13c4f9b945586a2994f3258d6
3
+ size 1744948136
model-00010-of-00010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c5a2c76f5b7960665d75ed473ae43eb76a5d86e27d8820ede0b9b16bb40b68c
3
+ size 3731831368
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_hyperclovax.py ADDED
@@ -0,0 +1,1866 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ CausalLMOutputWithPast,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
43
+ from transformers.utils import (
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_flash_attn_greater_or_equal_2_10,
47
+ is_torchdynamo_compiling,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+
52
+ from .configuration_hyperclovax import HyperCLOVAXConfig
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+ _CONFIG_FOR_DOC = "HyperCLOVAXConfig"
57
+
58
+
59
+ def _prepare_4d_causal_attention_mask_with_cache_position(
60
+ attention_mask: torch.Tensor,
61
+ sequence_length: int,
62
+ target_length: int,
63
+ dtype: torch.dtype,
64
+ device: torch.device,
65
+ min_dtype: float,
66
+ cache_position: torch.Tensor,
67
+ batch_size: int,
68
+ ):
69
+ """
70
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
71
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
72
+
73
+ Args:
74
+ attention_mask (`torch.Tensor`):
75
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
76
+ sequence_length (`int`):
77
+ The sequence length being processed.
78
+ target_length (`int`):
79
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not yet filled).
80
+ dtype (`torch.dtype`):
81
+ The dtype to use for the 4D attention mask.
82
+ device (`torch.device`):
83
+ The device to place the 4D attention mask on.
84
+ min_dtype (`float`):
85
+ The minimum value representable with the dtype `dtype`.
86
+ cache_position (`torch.Tensor`):
87
+ Indices depicting the position of the input sequence tokens in the sequence.
88
+ batch_size (`torch.Tensor`):
89
+ Batch size.
90
+ """
91
+ if attention_mask is not None and attention_mask.dim() == 4:
92
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
93
+ causal_mask = attention_mask
94
+ else:
95
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
96
+ if sequence_length != 1:
97
+ causal_mask = torch.triu(causal_mask, diagonal=1)
98
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
99
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
100
+ if attention_mask is not None:
101
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
102
+ mask_length = attention_mask.shape[-1]
103
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
104
+ padding_mask = padding_mask == 0
105
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
106
+
107
+ return causal_mask
108
+
109
+
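Editorial note (not part of the commit): a small sketch of calling the helper above with toy sizes and no padding:

import torch

causal = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(2, 4),
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.arange(4),
    batch_size=2,
)
print(causal.shape)   # torch.Size([2, 1, 4, 4]); entries above the diagonal hold min_dtype, the rest are 0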
110
+ class HyperCLOVAXRMSNorm(nn.Module):
111
+ def __init__(self, hidden_size, eps=1e-6):
112
+ """
113
+ HyperCLOVAXRMSNorm is equivalent to T5LayerNorm
114
+ """
115
+ super().__init__()
116
+ self.weight = nn.Parameter(torch.ones(hidden_size))
117
+ self.variance_epsilon = eps
118
+
119
+ def forward(self, hidden_states):
120
+ input_dtype = hidden_states.dtype
121
+ hidden_states = hidden_states.to(torch.float32)
122
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
123
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
124
+ return self.weight * hidden_states.to(input_dtype)
125
+
126
+ def extra_repr(self):
127
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
128
+
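Editorial note (not part of the commit): `HyperCLOVAXRMSNorm` follows the usual RMSNorm formula, x * weight / sqrt(mean(x**2, dim=-1) + eps), computed in float32. A quick check against a manual implementation:

import torch

norm = HyperCLOVAXRMSNorm(8, eps=1e-6)           # weight starts at ones
x = torch.randn(2, 3, 8)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), ref, atol=1e-6)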
129
+
130
+ ALL_LAYERNORM_LAYERS.append(HyperCLOVAXRMSNorm)
131
+
132
+
133
+ class HyperCLOVAXRotaryEmbedding(nn.Module):
134
+ def __init__(
135
+ self,
136
+ dim=None,
137
+ max_position_embeddings=2048,
138
+ base=10000,
139
+ device=None,
140
+ scaling_factor=1.0,
141
+ rope_type="default",
142
+ config: Optional[HyperCLOVAXConfig] = None,
143
+ ):
144
+ super().__init__()
145
+ # TODO (joao): remove the `if` below, only used for BC
146
+ self.rope_kwargs = {}
147
+ if config is None:
148
+ logger.warning_once(
149
+ "`HyperCLOVAXRotaryEmbedding` can now be fully parameterized by passing the model config through the "
150
+ "`config` argument. All other arguments will be removed in v4.46"
151
+ )
152
+ self.rope_kwargs = {
153
+ "rope_type": rope_type,
154
+ "factor": scaling_factor,
155
+ "dim": dim,
156
+ "base": base,
157
+ "max_position_embeddings": max_position_embeddings,
158
+ }
159
+ self.rope_type = rope_type
160
+ self.max_seq_len_cached = max_position_embeddings
161
+ self.original_max_seq_len = max_position_embeddings
162
+ else:
163
+ # BC: "rope_type" was originally "type"
164
+ if config.rope_scaling is not None:
165
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
166
+ else:
167
+ self.rope_type = "default"
168
+ self.max_seq_len_cached = config.max_position_embeddings
169
+ self.original_max_seq_len = config.max_position_embeddings
170
+
171
+ self.config = config
172
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
173
+
174
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
175
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
176
+ self.original_inv_freq = self.inv_freq
177
+
178
+ def _dynamic_frequency_update(self, position_ids, device):
179
+ """
180
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
181
+ 1 - growing beyond the cached sequence length (allow scaling)
182
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
183
+ """
184
+ seq_len = torch.max(position_ids) + 1
185
+ if seq_len > self.max_seq_len_cached: # growth
186
+ inv_freq, self.attention_scaling = self.rope_init_fn(
187
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
188
+ )
189
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
190
+ self.max_seq_len_cached = seq_len
191
+
192
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
193
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
194
+ self.max_seq_len_cached = self.original_max_seq_len
195
+
196
+ @torch.no_grad()
197
+ def forward(self, x, position_ids):
198
+ if "dynamic" in self.rope_type:
199
+ self._dynamic_frequency_update(position_ids, device=x.device)
200
+
201
+ # Core RoPE block
202
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
203
+ position_ids_expanded = position_ids[:, None, :].float()
204
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
205
+ device_type = x.device.type
206
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
207
+ with torch.autocast(device_type=device_type, enabled=False):
208
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
209
+ emb = torch.cat((freqs, freqs), dim=-1)
210
+ cos = emb.cos()
211
+ sin = emb.sin()
212
+
213
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
214
+ cos = cos * self.attention_scaling
215
+ sin = sin * self.attention_scaling
216
+
217
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
218
+
219
+
220
+ class HyperCLOVAXLinearScalingRotaryEmbedding(HyperCLOVAXRotaryEmbedding):
221
+ """HyperCLOVAXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
222
+
223
+ def __init__(self, *args, **kwargs):
224
+ logger.warning_once(
225
+ "`HyperCLOVAXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
226
+ "`HyperCLOVAXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
227
+ )
228
+ kwargs["rope_type"] = "linear"
229
+ super().__init__(*args, **kwargs)
230
+
231
+
232
+ class HyperCLOVAXDynamicNTKScalingRotaryEmbedding(HyperCLOVAXRotaryEmbedding):
233
+ """HyperCLOVAXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
234
+
235
+ def __init__(self, *args, **kwargs):
236
+ logger.warning_once(
237
+ "`HyperCLOVAXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
238
+ "`HyperCLOVAXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
239
+ "__init__)."
240
+ )
241
+ kwargs["rope_type"] = "dynamic"
242
+ super().__init__(*args, **kwargs)
243
+
244
+
245
+ def rotate_half(x):
246
+ """Rotates half the hidden dims of the input."""
247
+ x1 = x[..., : x.shape[-1] // 2]
248
+ x2 = x[..., x.shape[-1] // 2 :]
249
+ return torch.cat((-x2, x1), dim=-1)
250
+
251
+
252
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
253
+ """Applies Rotary Position Embedding to the query and key tensors.
254
+
255
+ Args:
256
+ q (`torch.Tensor`): The query tensor.
257
+ k (`torch.Tensor`): The key tensor.
258
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
259
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
260
+ position_ids (`torch.Tensor`, *optional*):
261
+ Deprecated and unused.
262
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
263
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
264
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
265
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
266
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
267
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
268
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
269
+ Returns:
270
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
271
+ """
272
+ cos = cos.unsqueeze(unsqueeze_dim)
273
+ sin = sin.unsqueeze(unsqueeze_dim)
274
+ q_embed = (q * cos) + (rotate_half(q) * sin)
275
+ k_embed = (k * cos) + (rotate_half(k) * sin)
276
+ return q_embed, k_embed
277
+
278
+
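Editorial note (not part of the commit): a tiny sanity check of `apply_rotary_pos_emb`. With cos = 1 and sin = 0 the rotation is the identity; the expected layout is (batch, num_heads, seq_len, head_dim) for q/k and (batch, seq_len, head_dim) for cos/sin:

import torch

q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
cos, sin = torch.ones(1, 16, 64), torch.zeros(1, 16, 64)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # unsqueeze_dim=1 broadcasts over heads
assert torch.equal(q_rot, q) and torch.equal(k_rot, k)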
279
+ class HyperCLOVAXMLP(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.config = config
283
+ self.hidden_size = config.hidden_size
284
+ self.intermediate_size = config.intermediate_size
285
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
286
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
287
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
288
+ self.act_fn = ACT2FN[config.hidden_act]
289
+
290
+ def forward(self, x):
291
+ if self.config.pretraining_tp > 1:
292
+ slice = self.intermediate_size // self.config.pretraining_tp
293
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
294
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
295
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
296
+
297
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
298
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
299
+
300
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
301
+ down_proj = [
302
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
303
+ ]
304
+ down_proj = sum(down_proj)
305
+ else:
306
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
307
+
308
+ return down_proj
309
+
310
+
311
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
312
+ """
313
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
314
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
315
+ """
316
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
317
+ if n_rep == 1:
318
+ return hidden_states
319
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
320
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
321
+
322
+
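Editorial note (not part of the commit): a quick check that `repeat_kv` matches `torch.repeat_interleave` on the key/value head dimension, as the docstring states:

import torch

x = torch.randn(2, 4, 7, 16)   # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))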
323
+ class HyperCLOVAXAttention(nn.Module):
324
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
325
+
326
+ def __init__(self, config: HyperCLOVAXConfig, layer_idx: Optional[int] = None):
327
+ super().__init__()
328
+ self.config = config
329
+ self.layer_idx = layer_idx
330
+ if layer_idx is None:
331
+ logger.warning_once(
332
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
333
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
334
+ "when creating this class."
335
+ )
336
+
337
+ self.attention_dropout = config.attention_dropout
338
+ self.hidden_size = config.hidden_size
339
+ self.num_heads = config.num_attention_heads
340
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
341
+ self.num_key_value_heads = config.num_key_value_heads
342
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
343
+ self.max_position_embeddings = config.max_position_embeddings
344
+ self.rope_theta = config.rope_theta
345
+ self.is_causal = True
346
+
347
+ self.scaling = config.attention_multiplier
348
+
349
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
350
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
351
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
352
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
353
+
354
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
355
+ self.rotary_emb = HyperCLOVAXRotaryEmbedding(config=self.config)
356
+
357
+ def forward(
358
+ self,
359
+ hidden_states: torch.Tensor,
360
+ attention_mask: Optional[torch.Tensor] = None,
361
+ position_ids: Optional[torch.LongTensor] = None,
362
+ past_key_value: Optional[Cache] = None,
363
+ output_attentions: bool = False,
364
+ use_cache: bool = False,
365
+ cache_position: Optional[torch.LongTensor] = None,
366
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
367
+ **kwargs,
368
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
369
+ bsz, q_len, _ = hidden_states.size()
370
+
371
+ if self.config.pretraining_tp > 1:
372
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
373
+ query_slices = self.q_proj.weight.split(
374
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
375
+ )
376
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
377
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
378
+
379
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
380
+ query_states = torch.cat(query_states, dim=-1)
381
+
382
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
383
+ key_states = torch.cat(key_states, dim=-1)
384
+
385
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
386
+ value_states = torch.cat(value_states, dim=-1)
387
+
388
+ else:
389
+ query_states = self.q_proj(hidden_states)
390
+ key_states = self.k_proj(hidden_states)
391
+ value_states = self.v_proj(hidden_states)
392
+
393
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
394
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
395
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
396
+
397
+ if position_embeddings is None:
398
+ logger.warning_once(
399
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
400
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
401
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
402
+ "removed and `position_embeddings` will be mandatory."
403
+ )
404
+ cos, sin = self.rotary_emb(value_states, position_ids)
405
+ else:
406
+ cos, sin = position_embeddings
407
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
408
+
409
+ if past_key_value is not None:
410
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
411
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
412
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
413
+
414
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
415
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
416
+ # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling / math.sqrt(self.head_dim)
417
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
418
+
419
+ if attention_mask is not None: # no matter the length, we just slice it
420
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
421
+ attn_weights = attn_weights + causal_mask
422
+
423
+ # upcast attention to fp32
424
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
425
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
426
+ attn_output = torch.matmul(attn_weights, value_states)
427
+
428
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
429
+ raise ValueError(
430
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
431
+ f" {attn_output.size()}"
432
+ )
433
+
434
+ attn_output = attn_output.transpose(1, 2).contiguous()
435
+
436
+ attn_output = attn_output.reshape(bsz, q_len, -1)
437
+
438
+ if self.config.pretraining_tp > 1:
439
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
440
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
441
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
442
+ else:
443
+ attn_output = self.o_proj(attn_output)
444
+
445
+ if not output_attentions:
446
+ attn_weights = None
447
+
448
+ return attn_output, attn_weights, past_key_value
449
+
450
+
451
+ class HyperCLOVAXFlashAttention2(HyperCLOVAXAttention):
452
+ """
453
+ HyperCLOVAX flash attention module. This module inherits from `HyperCLOVAXAttention` as the weights of the module stay
454
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
455
+ flash attention and deal with padding tokens in case the input contains any of them.
456
+ """
457
+
458
+ def __init__(self, *args, **kwargs):
459
+ super().__init__(*args, **kwargs)
460
+
461
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
462
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
463
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
464
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
465
+
466
+ def forward(
467
+ self,
468
+ hidden_states: torch.Tensor,
469
+ attention_mask: Optional[torch.LongTensor] = None,
470
+ position_ids: Optional[torch.LongTensor] = None,
471
+ past_key_value: Optional[Cache] = None,
472
+ output_attentions: bool = False,
473
+ use_cache: bool = False,
474
+ cache_position: Optional[torch.LongTensor] = None,
475
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
476
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
477
+ if isinstance(past_key_value, StaticCache):
478
+ raise ValueError(
479
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
480
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
481
+ )
482
+
483
+ output_attentions = False
484
+
485
+ bsz, q_len, _ = hidden_states.size()
486
+
487
+ query_states = self.q_proj(hidden_states)
488
+ key_states = self.k_proj(hidden_states)
489
+ value_states = self.v_proj(hidden_states)
490
+
491
+ # Flash attention requires the input to have the shape
492
+ # batch_size x seq_length x num_heads x head_dim
493
+ # therefore we just need to keep the original shape
494
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
495
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
496
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
497
+
498
+ if position_embeddings is None:
499
+ logger.warning_once(
500
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
501
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
502
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
503
+ "removed and `position_embeddings` will be mandatory."
504
+ )
505
+ cos, sin = self.rotary_emb(value_states, position_ids)
506
+ else:
507
+ cos, sin = position_embeddings
508
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
509
+
510
+ if past_key_value is not None:
511
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
512
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
513
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
514
+
515
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
516
+ # to be able to avoid many of these transpose/reshape/view.
517
+ query_states = query_states.transpose(1, 2)
518
+ key_states = key_states.transpose(1, 2)
519
+ value_states = value_states.transpose(1, 2)
520
+
521
+ dropout_rate = self.attention_dropout if self.training else 0.0
522
+
523
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
524
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
525
+ # cast them back in the correct dtype just to be sure everything works as expected.
526
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
527
+ # in fp32. (HyperCLOVAXRMSNorm handles it correctly)
528
+
529
+ input_dtype = query_states.dtype
530
+ if input_dtype == torch.float32:
531
+ if torch.is_autocast_enabled():
532
+ target_dtype = torch.get_autocast_gpu_dtype()
533
+ # Handle the case where the model is quantized
534
+ elif hasattr(self.config, "_pre_quantization_dtype"):
535
+ target_dtype = self.config._pre_quantization_dtype
536
+ else:
537
+ target_dtype = self.q_proj.weight.dtype
538
+
539
+ logger.warning_once(
540
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
541
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
542
+ f" {target_dtype}."
543
+ )
544
+
545
+ query_states = query_states.to(target_dtype)
546
+ key_states = key_states.to(target_dtype)
547
+ value_states = value_states.to(target_dtype)
548
+
549
+ attn_output = _flash_attention_forward(
550
+ query_states,
551
+ key_states,
552
+ value_states,
553
+ attention_mask,
554
+ q_len,
555
+ position_ids=position_ids,
556
+ dropout=dropout_rate,
557
+ softmax_scale=self.scaling, # mup
558
+ sliding_window=getattr(self, "sliding_window", None),
559
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
560
+ is_causal=self.is_causal,
561
+ )
562
+
563
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
564
+ attn_output = self.o_proj(attn_output)
565
+
566
+ if not output_attentions:
567
+ attn_weights = None
568
+
569
+ return attn_output, attn_weights, past_key_value
570
+
571
+
572
+ class HyperCLOVAXSdpaAttention(HyperCLOVAXAttention):
573
+ """
574
+ HyperCLOVAX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
575
+ `HyperCLOVAXAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
576
+ SDPA API.
577
+ """
578
+
579
+ # Adapted from HyperCLOVAXAttention.forward
580
+ def forward(
581
+ self,
582
+ hidden_states: torch.Tensor,
583
+ attention_mask: Optional[torch.Tensor] = None,
584
+ position_ids: Optional[torch.LongTensor] = None,
585
+ past_key_value: Optional[Cache] = None,
586
+ output_attentions: bool = False,
587
+ use_cache: bool = False,
588
+ cache_position: Optional[torch.LongTensor] = None,
589
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
590
+ **kwargs,
591
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
592
+ if output_attentions:
593
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
594
+ logger.warning_once(
595
+ "HyperCLOVAXModel is using HyperCLOVAXSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
596
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
597
+ )
598
+ return super().forward(
599
+ hidden_states=hidden_states,
600
+ attention_mask=attention_mask,
601
+ position_ids=position_ids,
602
+ past_key_value=past_key_value,
603
+ output_attentions=output_attentions,
604
+ use_cache=use_cache,
605
+ cache_position=cache_position,
606
+ position_embeddings=position_embeddings,
607
+ )
608
+
609
+ bsz, q_len, _ = hidden_states.size()
610
+
611
+ query_states = self.q_proj(hidden_states)
612
+ key_states = self.k_proj(hidden_states)
613
+ value_states = self.v_proj(hidden_states)
614
+
615
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
616
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
617
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
618
+
619
+ if position_embeddings is None:
620
+ logger.warning_once(
621
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
622
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
623
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
624
+ "removed and `position_embeddings` will be mandatory."
625
+ )
626
+ cos, sin = self.rotary_emb(value_states, position_ids)
627
+ else:
628
+ cos, sin = position_embeddings
629
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
630
+
631
+ if past_key_value is not None:
632
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
633
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
634
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
635
+
636
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
637
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
638
+
639
+ causal_mask = attention_mask
640
+ if attention_mask is not None:
641
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
642
+
643
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
644
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
645
+ if query_states.device.type == "cuda" and causal_mask is not None:
646
+ query_states = query_states.contiguous()
647
+ key_states = key_states.contiguous()
648
+ value_states = value_states.contiguous()
649
+
650
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
651
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
652
+ is_causal = True if causal_mask is None and q_len > 1 else False
653
+
654
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
655
+ query_states,
656
+ key_states,
657
+ value_states,
658
+ attn_mask=causal_mask,
659
+ dropout_p=self.attention_dropout if self.training else 0.0,
660
+ is_causal=is_causal,
661
+ scale=self.scaling, # mup
662
+ )
663
+
664
+ attn_output = attn_output.transpose(1, 2).contiguous()
665
+ attn_output = attn_output.view(bsz, q_len, -1)
666
+
667
+ attn_output = self.o_proj(attn_output)
668
+
669
+ return attn_output, None, past_key_value
670
+
671
+
672
+ HyperCLOVAX_ATTENTION_CLASSES = {
673
+ "eager": HyperCLOVAXAttention,
674
+ "flash_attention_2": HyperCLOVAXFlashAttention2,
675
+ "sdpa": HyperCLOVAXSdpaAttention,
676
+ }
677
+
678
+
679
+ class HyperCLOVAXDecoderLayer(nn.Module):
680
+ def __init__(self, config: HyperCLOVAXConfig, layer_idx: int):
681
+ super().__init__()
682
+ self.hidden_size = config.hidden_size
683
+
684
+ self.self_attn = HyperCLOVAX_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
685
+
686
+ self.mlp = HyperCLOVAXMLP(config)
687
+ self.input_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
688
+ self.post_attention_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
689
+
690
+ # post-norm (dual-norm)
691
+ self.use_post_norm = config.use_post_norm
692
+ if self.use_post_norm:
693
+ self.post_norm1 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
694
+ self.post_norm2 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
695
+
696
+ self.residual_multiplier = config.residual_multiplier # mup
697
+
698
+ def forward(
699
+ self,
700
+ hidden_states: torch.Tensor,
701
+ attention_mask: Optional[torch.Tensor] = None,
702
+ position_ids: Optional[torch.LongTensor] = None,
703
+ past_key_value: Optional[Cache] = None,
704
+ output_attentions: Optional[bool] = False,
705
+ use_cache: Optional[bool] = False,
706
+ cache_position: Optional[torch.LongTensor] = None,
707
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
708
+ **kwargs,
709
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
710
+ """
711
+ Args:
712
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
713
+ attention_mask (`torch.FloatTensor`, *optional*):
714
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
715
+ query_sequence_length, key_sequence_length)` if default attention is used.
716
+ output_attentions (`bool`, *optional*):
717
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
718
+ returned tensors for more detail.
719
+ use_cache (`bool`, *optional*):
720
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
721
+ (see `past_key_values`).
722
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
723
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
724
+ Indices depicting the position of the input sequence tokens in the sequence
725
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
726
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
727
+ with `head_dim` being the embedding dimension of each attention head.
728
+ kwargs (`dict`, *optional*):
729
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
730
+ into the model
731
+ """
732
+ residual = hidden_states
733
+
734
+ hidden_states = self.input_layernorm(hidden_states)
735
+
736
+ # Self Attention
737
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
738
+ hidden_states=hidden_states,
739
+ attention_mask=attention_mask,
740
+ position_ids=position_ids,
741
+ past_key_value=past_key_value,
742
+ output_attentions=output_attentions,
743
+ use_cache=use_cache,
744
+ cache_position=cache_position,
745
+ position_embeddings=position_embeddings,
746
+ **kwargs,
747
+ )
748
+
749
+ if self.use_post_norm:
750
+ hidden_states = self.post_norm1(hidden_states)
751
+
752
+ hidden_states = residual + hidden_states * self.residual_multiplier # mup
753
+
754
+ # Fully Connected
755
+ residual = hidden_states
756
+ hidden_states = self.post_attention_layernorm(hidden_states)
757
+ hidden_states = self.mlp(hidden_states)
758
+
759
+ if self.use_post_norm:
760
+ hidden_states = self.post_norm2(hidden_states)
761
+
762
+ hidden_states = residual + hidden_states * self.residual_multiplier # mup
763
+
764
+ outputs = (hidden_states,)
765
+
766
+ if output_attentions:
767
+ outputs += (self_attn_weights,)
768
+
769
+ if use_cache:
770
+ outputs += (present_key_value,)
771
+
772
+ return outputs
773
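# Schematic of the residual pattern implemented by HyperCLOVAXDecoderLayer above, written out as
# a plain function (illustrative only; `attn` and `mlp` stand in for the real sub-modules). With
# `use_post_norm=True` each sub-block is normalised both before and after ("dual-norm"), and the
# residual branch is scaled by `residual_multiplier` (muP).
def decoder_block(x, attn, mlp, pre_norm1, pre_norm2, post_norm1, post_norm2, residual_multiplier):
    h = post_norm1(attn(pre_norm1(x)))   # attention sub-block: pre-norm -> attention -> post-norm
    x = x + h * residual_multiplier      # scaled residual add (muP)
    h = post_norm2(mlp(pre_norm2(x)))    # MLP sub-block: pre-norm -> MLP -> post-norm
    return x + h * residual_multiplier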
+
774
+
775
+ HyperCLOVAX_START_DOCSTRING = r"""
776
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
777
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
778
+ etc.)
779
+
780
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
781
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
782
+ and behavior.
783
+
784
+ Parameters:
785
+ config ([`HyperCLOVAXConfig`]):
786
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
787
+ load the weights associated with the model, only the configuration. Check out the
788
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
789
+ """
790
+
791
+
792
+ @add_start_docstrings(
793
+ "The bare HyperCLOVAX Model outputting raw hidden-states without any specific head on top.",
794
+ HyperCLOVAX_START_DOCSTRING,
795
+ )
796
+ class HyperCLOVAXPreTrainedModel(PreTrainedModel):
797
+ config_class = HyperCLOVAXConfig
798
+ base_model_prefix = "model"
799
+ supports_gradient_checkpointing = True
800
+ _no_split_modules = ["HyperCLOVAXDecoderLayer"]
801
+ _skip_keys_device_placement = ["past_key_values"]
802
+ _supports_flash_attn_2 = True
803
+ _supports_sdpa = True
804
+ _supports_cache_class = True
805
+ _supports_quantized_cache = True
806
+ _supports_static_cache = True
807
+
808
+ def _init_weights(self, module):
809
+ std = self.config.initializer_range
810
+ if isinstance(module, nn.Linear):
811
+ module.weight.data.normal_(mean=0.0, std=std)
812
+ if module.bias is not None:
813
+ module.bias.data.zero_()
814
+ elif isinstance(module, nn.Embedding):
815
+ module.weight.data.normal_(mean=0.0, std=std)
816
+ if module.padding_idx is not None:
817
+ module.weight.data[module.padding_idx].zero_()
818
+
819
+
820
+ HyperCLOVAX_INPUTS_DOCSTRING = r"""
821
+ Args:
822
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
823
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
824
+ it.
825
+
826
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
827
+ [`PreTrainedTokenizer.__call__`] for details.
828
+
829
+ [What are input IDs?](../glossary#input-ids)
830
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
831
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
832
+
833
+ - 1 for tokens that are **not masked**,
834
+ - 0 for tokens that are **masked**.
835
+
836
+ [What are attention masks?](../glossary#attention-mask)
837
+
838
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
839
+ [`PreTrainedTokenizer.__call__`] for details.
840
+
841
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
842
+ `past_key_values`).
843
+
844
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
845
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
846
+ information on the default strategy.
847
+
848
+ - 1 indicates the head is **not masked**,
849
+ - 0 indicates the head is **masked**.
850
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
851
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
852
+ config.n_positions - 1]`.
853
+
854
+ [What are position IDs?](../glossary#position-ids)
855
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
856
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
857
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
858
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
859
+
860
+ Two formats are allowed:
861
+ - a [`~cache_utils.Cache`] instance, see our
862
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
863
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
864
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
865
+ cache format.
866
+
867
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
868
+ legacy cache format will be returned.
869
+
870
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
871
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
872
+ of shape `(batch_size, sequence_length)`.
873
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
874
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
875
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
876
+ model's internal embedding lookup matrix.
877
+ use_cache (`bool`, *optional*):
878
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
879
+ `past_key_values`).
880
+ output_attentions (`bool`, *optional*):
881
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
882
+ tensors for more detail.
883
+ output_hidden_states (`bool`, *optional*):
884
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
885
+ more detail.
886
+ return_dict (`bool`, *optional*):
887
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
888
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
889
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
890
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
891
+ the complete sequence length.
892
+ """
893
+
894
+
895
+ @add_start_docstrings(
896
+ "The bare HyperCLOVAX Model outputting raw hidden-states without any specific head on top.",
897
+ HyperCLOVAX_START_DOCSTRING,
898
+ )
899
+ class HyperCLOVAXModel(HyperCLOVAXPreTrainedModel):
900
+ """
901
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HyperCLOVAXDecoderLayer`]
902
+
903
+ Args:
904
+ config: HyperCLOVAXConfig
905
+ """
906
+
907
+ def __init__(self, config: HyperCLOVAXConfig):
908
+ super().__init__(config)
909
+ self.padding_idx = config.pad_token_id
910
+ self.vocab_size = config.vocab_size
911
+
912
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
913
+ self.layers = nn.ModuleList(
914
+ [HyperCLOVAXDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
915
+ )
916
+ self.norm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
917
+ self.rotary_emb = HyperCLOVAXRotaryEmbedding(config=config)
918
+ self.gradient_checkpointing = False
919
+
920
+ # Initialize weights and apply final processing
921
+ self.post_init()
922
+
923
+ # mup
924
+ self.embedding_multiplier = config.embedding_multiplier
925
+
926
+ def get_input_embeddings(self):
927
+ return self.embed_tokens
928
+
929
+ def set_input_embeddings(self, value):
930
+ self.embed_tokens = value
931
+
932
+ @add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
933
+ def forward(
934
+ self,
935
+ input_ids: torch.LongTensor = None,
936
+ attention_mask: Optional[torch.Tensor] = None,
937
+ position_ids: Optional[torch.LongTensor] = None,
938
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
939
+ inputs_embeds: Optional[torch.FloatTensor] = None,
940
+ use_cache: Optional[bool] = None,
941
+ output_attentions: Optional[bool] = None,
942
+ output_hidden_states: Optional[bool] = None,
943
+ return_dict: Optional[bool] = None,
944
+ cache_position: Optional[torch.LongTensor] = None,
945
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
946
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
947
+ output_hidden_states = (
948
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
949
+ )
950
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
951
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
952
+
953
+ if (input_ids is None) ^ (inputs_embeds is not None):
954
+ raise ValueError(
955
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
956
+ )
957
+
958
+ if self.gradient_checkpointing and self.training and use_cache:
959
+ logger.warning_once(
960
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
961
+ )
962
+ use_cache = False
963
+
964
+ if inputs_embeds is None:
965
+ inputs_embeds = self.embed_tokens(input_ids)
966
+
967
+ inputs_embeds = inputs_embeds * self.embedding_multiplier # mup
968
+
969
+ # kept for BC (non `Cache` `past_key_values` inputs)
970
+ return_legacy_cache = False
971
+ if use_cache and not isinstance(past_key_values, Cache):
972
+ return_legacy_cache = True
973
+ if past_key_values is None:
974
+ past_key_values = DynamicCache()
975
+ else:
976
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
977
+ logger.warning_once(
978
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
979
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
980
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
981
+ )
982
+
983
+ if cache_position is None:
984
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
985
+ cache_position = torch.arange(
986
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
987
+ )
988
+ if position_ids is None:
989
+ position_ids = cache_position.unsqueeze(0)
990
+
991
+ causal_mask = self._update_causal_mask(
992
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
993
+ )
994
+ hidden_states = inputs_embeds
995
+
996
+ # create position embeddings to be shared across the decoder layers
997
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
998
+
999
+ # decoder layers
1000
+ all_hidden_states = () if output_hidden_states else None
1001
+ all_self_attns = () if output_attentions else None
1002
+ next_decoder_cache = None
1003
+
1004
+ for decoder_layer in self.layers:
1005
+ if output_hidden_states:
1006
+ all_hidden_states += (hidden_states,)
1007
+
1008
+ if self.gradient_checkpointing and self.training:
1009
+ layer_outputs = self._gradient_checkpointing_func(
1010
+ decoder_layer.__call__,
1011
+ hidden_states,
1012
+ causal_mask,
1013
+ position_ids,
1014
+ past_key_values,
1015
+ output_attentions,
1016
+ use_cache,
1017
+ cache_position,
1018
+ position_embeddings,
1019
+ )
1020
+ else:
1021
+ layer_outputs = decoder_layer(
1022
+ hidden_states,
1023
+ attention_mask=causal_mask,
1024
+ position_ids=position_ids,
1025
+ past_key_value=past_key_values,
1026
+ output_attentions=output_attentions,
1027
+ use_cache=use_cache,
1028
+ cache_position=cache_position,
1029
+ position_embeddings=position_embeddings,
1030
+ )
1031
+
1032
+ hidden_states = layer_outputs[0]
1033
+
1034
+ if use_cache:
1035
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1036
+
1037
+ if output_attentions:
1038
+ all_self_attns += (layer_outputs[1],)
1039
+
1040
+ hidden_states = self.norm(hidden_states)
1041
+
1042
+ # add hidden states from the last decoder layer
1043
+ if output_hidden_states:
1044
+ all_hidden_states += (hidden_states,)
1045
+
1046
+ next_cache = next_decoder_cache if use_cache else None
1047
+ if return_legacy_cache:
1048
+ next_cache = next_cache.to_legacy_cache()
1049
+
1050
+ if not return_dict:
1051
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1052
+ return BaseModelOutputWithPast(
1053
+ last_hidden_state=hidden_states,
1054
+ past_key_values=next_cache,
1055
+ hidden_states=all_hidden_states,
1056
+ attentions=all_self_attns,
1057
+ )
1058
+
1059
+ def _update_causal_mask(
1060
+ self,
1061
+ attention_mask: torch.Tensor,
1062
+ input_tensor: torch.Tensor,
1063
+ cache_position: torch.Tensor,
1064
+ past_key_values: Cache,
1065
+ output_attentions: bool,
1066
+ ):
1067
+ if self.config._attn_implementation == "flash_attention_2":
1068
+ if attention_mask is not None and 0.0 in attention_mask:
1069
+ return attention_mask
1070
+ return None
1071
+
1072
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1073
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1074
+ # to infer the attention mask.
1075
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1076
+ using_static_cache = isinstance(past_key_values, StaticCache)
1077
+
1078
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1079
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1080
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1081
+ attention_mask,
1082
+ inputs_embeds=input_tensor,
1083
+ past_key_values_length=past_seen_tokens,
1084
+ is_training=self.training,
1085
+ ):
1086
+ return None
1087
+
1088
+ dtype, device = input_tensor.dtype, input_tensor.device
1089
+ min_dtype = torch.finfo(dtype).min
1090
+ sequence_length = input_tensor.shape[1]
1091
+ if using_static_cache:
1092
+ target_length = past_key_values.get_max_length()
1093
+ else:
1094
+ target_length = (
1095
+ attention_mask.shape[-1]
1096
+ if isinstance(attention_mask, torch.Tensor)
1097
+ else past_seen_tokens + sequence_length + 1
1098
+ )
1099
+
1100
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1101
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1102
+ attention_mask,
1103
+ sequence_length=sequence_length,
1104
+ target_length=target_length,
1105
+ dtype=dtype,
1106
+ device=device,
1107
+ min_dtype=min_dtype,
1108
+ cache_position=cache_position,
1109
+ batch_size=input_tensor.shape[0],
1110
+ )
1111
+
1112
+ if (
1113
+ self.config._attn_implementation == "sdpa"
1114
+ and attention_mask is not None
1115
+ and attention_mask.device.type == "cuda"
1116
+ and not output_attentions
1117
+ ):
1118
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1119
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1120
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1121
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1122
+
1123
+ return causal_mask
1124
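# Sketch (illustrative, not part of the file) of the cache handling in HyperCLOVAXModel.forward
# above: a legacy tuple-of-tuples cache is converted to a `DynamicCache` on entry and, because
# `return_legacy_cache` is set, converted back on exit so callers keep their original format.
from transformers.cache_utils import DynamicCache

legacy_cache = None  # or a tuple of per-layer (key, value) tuples in the legacy format
cache = DynamicCache() if legacy_cache is None else DynamicCache.from_legacy_cache(legacy_cache)
past_seen_tokens = cache.get_seq_length()  # 0 for an empty cache
# ... the decoder layers update `cache` in place when `use_cache=True` ...
legacy_again = cache.to_legacy_cache()     # returned when a legacy cache was passed in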
+
1125
+
1126
+ class HyperCLOVAXForCausalLM(HyperCLOVAXPreTrainedModel, GenerationMixin):
1127
+ _tied_weights_keys = ["lm_head.weight"]
1128
+
1129
+ def __init__(self, config):
1130
+ super().__init__(config)
1131
+ self.model = HyperCLOVAXModel(config)
1132
+ self.vocab_size = config.vocab_size
1133
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1134
+
1135
+ # Initialize weights and apply final processing
1136
+ self.post_init()
1137
+
1138
+ def _get_apply_liger_kernel_converter(self):
1139
+ return _apply_liger_kernel_to_instance
1140
+
1141
+ def get_input_embeddings(self):
1142
+ return self.model.embed_tokens
1143
+
1144
+ def set_input_embeddings(self, value):
1145
+ self.model.embed_tokens = value
1146
+
1147
+ def get_output_embeddings(self):
1148
+ return self.lm_head
1149
+
1150
+ def set_output_embeddings(self, new_embeddings):
1151
+ self.lm_head = new_embeddings
1152
+
1153
+ def set_decoder(self, decoder):
1154
+ self.model = decoder
1155
+
1156
+ def get_decoder(self):
1157
+ return self.model
1158
+
1159
+ @add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
1160
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1161
+ def forward(
1162
+ self,
1163
+ input_ids: torch.LongTensor = None,
1164
+ attention_mask: Optional[torch.Tensor] = None,
1165
+ position_ids: Optional[torch.LongTensor] = None,
1166
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1167
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1168
+ labels: Optional[torch.LongTensor] = None,
1169
+ use_cache: Optional[bool] = None,
1170
+ output_attentions: Optional[bool] = None,
1171
+ output_hidden_states: Optional[bool] = None,
1172
+ return_dict: Optional[bool] = None,
1173
+ cache_position: Optional[torch.LongTensor] = None,
1174
+ num_logits_to_keep: int = 0,
1175
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1176
+ r"""
1177
+ Args:
1178
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1179
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1180
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1181
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1182
+
1183
+ num_logits_to_keep (`int`, *optional*):
1184
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1185
+ `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
1186
+ only for that token saves memory, which becomes significant for long sequences or large vocabularies.
1187
+
1188
+ Returns:
1189
+
1190
+ Example:
1191
+
1192
+ ```python
1193
+ >>> from transformers import AutoTokenizer, HyperCLOVAXForCausalLM
1194
+
1195
+ >>> model = HyperCLOVAXForCausalLM.from_pretrained(YOUR_DIR)
1196
+ >>> tokenizer = AutoTokenizer.from_pretrained(YOUR_DIR)
1197
+
1198
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1199
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1200
+
1201
+ >>> # Generate
1202
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1203
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1204
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1205
+ ```"""
1206
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1207
+ output_hidden_states = (
1208
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1209
+ )
1210
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1211
+
1212
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1213
+ outputs = self.model(
1214
+ input_ids=input_ids,
1215
+ attention_mask=attention_mask,
1216
+ position_ids=position_ids,
1217
+ past_key_values=past_key_values,
1218
+ inputs_embeds=inputs_embeds,
1219
+ use_cache=use_cache,
1220
+ output_attentions=output_attentions,
1221
+ output_hidden_states=output_hidden_states,
1222
+ return_dict=return_dict,
1223
+ cache_position=cache_position,
1224
+ )
1225
+
1226
+ hidden_states = outputs[0]
1227
+ if self.config.pretraining_tp > 1:
1228
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1229
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1230
+ logits = torch.cat(logits, dim=-1)
1231
+ else:
1232
+ if labels is None and not is_torchdynamo_compiling():
1233
+ logger.warning_once(
1234
+ "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
1235
+ )
1236
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1237
+ # TODO: remove the float() operation in v4.46
1238
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1239
+
1240
+ logits = logits * self.config.logits_scaling # mup
1241
+
1242
+ loss = None
1243
+ if labels is not None:
1244
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1245
+ logits = logits.float()
1246
+ # Shift so that tokens < n predict n
1247
+ shift_logits = logits[..., :-1, :].contiguous()
1248
+ shift_labels = labels[..., 1:].contiguous()
1249
+ # Flatten the tokens
1250
+ loss_fct = CrossEntropyLoss()
1251
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1252
+ shift_labels = shift_labels.view(-1)
1253
+ # Enable model parallelism
1254
+ shift_labels = shift_labels.to(shift_logits.device)
1255
+ loss = loss_fct(shift_logits, shift_labels)
1256
+
1257
+ if not return_dict:
1258
+ output = (logits,) + outputs[1:]
1259
+ return (loss,) + output if loss is not None else output
1260
+
1261
+ return CausalLMOutputWithPast(
1262
+ loss=loss,
1263
+ logits=logits,
1264
+ past_key_values=outputs.past_key_values,
1265
+ hidden_states=outputs.hidden_states,
1266
+ attentions=outputs.attentions,
1267
+ )
1268
+
1269
+ def prepare_inputs_for_generation(
1270
+ self,
1271
+ input_ids,
1272
+ past_key_values=None,
1273
+ attention_mask=None,
1274
+ inputs_embeds=None,
1275
+ cache_position=None,
1276
+ position_ids=None,
1277
+ use_cache=True,
1278
+ num_logits_to_keep=None,
1279
+ **kwargs,
1280
+ ):
1281
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1282
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1283
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1284
+ if past_key_values is not None:
1285
+ if inputs_embeds is not None: # Exception 1
1286
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1287
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1288
+ input_ids = input_ids[:, cache_position]
1289
+
1290
+ if attention_mask is not None and position_ids is None:
1291
+ # create position_ids on the fly for batch generation
1292
+ position_ids = attention_mask.long().cumsum(-1) - 1
1293
+ position_ids.masked_fill_(attention_mask == 0, 1)
1294
+ if past_key_values:
1295
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1296
+
1297
+ # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
1298
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1299
+
1300
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1301
+ if inputs_embeds is not None and cache_position[0] == 0:
1302
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1303
+ else:
1304
+ # The clone here is for the same reason as for `position_ids`.
1305
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1306
+
1307
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1308
+ if model_inputs["inputs_embeds"] is not None:
1309
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1310
+ device = model_inputs["inputs_embeds"].device
1311
+ else:
1312
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1313
+ device = model_inputs["input_ids"].device
1314
+
1315
+ dtype = self.lm_head.weight.dtype
1316
+ min_dtype = torch.finfo(dtype).min
1317
+
1318
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1319
+ attention_mask,
1320
+ sequence_length=sequence_length,
1321
+ target_length=past_key_values.get_max_length(),
1322
+ dtype=dtype,
1323
+ device=device,
1324
+ min_dtype=min_dtype,
1325
+ cache_position=cache_position,
1326
+ batch_size=batch_size,
1327
+ )
1328
+
1329
+ if num_logits_to_keep is not None:
1330
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
1331
+
1332
+ model_inputs.update(
1333
+ {
1334
+ "position_ids": position_ids,
1335
+ "cache_position": cache_position,
1336
+ "past_key_values": past_key_values,
1337
+ "use_cache": use_cache,
1338
+ "attention_mask": attention_mask,
1339
+ }
1340
+ )
1341
+ return model_inputs
1342
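# Illustration (sizes are made up) of the `num_logits_to_keep` behaviour documented in the forward
# above: with the default 0 the slice `[:, -0:, :]` is the whole sequence, so the lm_head projects
# every position, while 1 projects only the final hidden state, which is all that decoding needs.
import torch

hidden_states = torch.randn(1, 128, 4096)              # (batch, seq_len, hidden_size)
lm_head = torch.nn.Linear(4096, 110_000, bias=False)   # vocab size here is illustrative
all_logits = lm_head(hidden_states[:, -0:, :])          # -0: == 0: -> logits for all 128 positions
last_logits = lm_head(hidden_states[:, -1:, :])         # logits for the last position only
assert all_logits.shape[1] == 128 and last_logits.shape[1] == 1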
+
1343
+
1344
+ @add_start_docstrings(
1345
+ """
1346
+ The HyperCLOVAX Model transformer with a sequence classification head on top (linear layer).
1347
+
1348
+ [`HyperCLOVAXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1349
+ (e.g. GPT-2) do.
1350
+
1351
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1352
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1353
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1354
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1355
+ each row of the batch).
1356
+ """,
1357
+ HyperCLOVAX_START_DOCSTRING,
1358
+ )
1359
+ class HyperCLOVAXForSequenceClassification(HyperCLOVAXPreTrainedModel):
1360
+ def __init__(self, config):
1361
+ super().__init__(config)
1362
+ self.num_labels = config.num_labels
1363
+ self.model = HyperCLOVAXModel(config)
1364
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1365
+
1366
+ # Initialize weights and apply final processing
1367
+ self.post_init()
1368
+
1369
+ def get_input_embeddings(self):
1370
+ return self.model.embed_tokens
1371
+
1372
+ def set_input_embeddings(self, value):
1373
+ self.model.embed_tokens = value
1374
+
1375
+ @add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
1376
+ def forward(
1377
+ self,
1378
+ input_ids: Optional[torch.LongTensor] = None,
1379
+ attention_mask: Optional[torch.Tensor] = None,
1380
+ position_ids: Optional[torch.LongTensor] = None,
1381
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1382
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1383
+ labels: Optional[torch.LongTensor] = None,
1384
+ use_cache: Optional[bool] = None,
1385
+ output_attentions: Optional[bool] = None,
1386
+ output_hidden_states: Optional[bool] = None,
1387
+ return_dict: Optional[bool] = None,
1388
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1389
+ r"""
1390
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1391
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1392
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1393
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1394
+ """
1395
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1396
+
1397
+ transformer_outputs = self.model(
1398
+ input_ids,
1399
+ attention_mask=attention_mask,
1400
+ position_ids=position_ids,
1401
+ past_key_values=past_key_values,
1402
+ inputs_embeds=inputs_embeds,
1403
+ use_cache=use_cache,
1404
+ output_attentions=output_attentions,
1405
+ output_hidden_states=output_hidden_states,
1406
+ return_dict=return_dict,
1407
+ )
1408
+ hidden_states = transformer_outputs[0]
1409
+ logits = self.score(hidden_states)
1410
+
1411
+ if input_ids is not None:
1412
+ batch_size = input_ids.shape[0]
1413
+ else:
1414
+ batch_size = inputs_embeds.shape[0]
1415
+
1416
+ if self.config.pad_token_id is None and batch_size != 1:
1417
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1418
+ if self.config.pad_token_id is None:
1419
+ sequence_lengths = -1
1420
+ else:
1421
+ if input_ids is not None:
1422
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1423
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1424
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1425
+ sequence_lengths = sequence_lengths.to(logits.device)
1426
+ else:
1427
+ sequence_lengths = -1
1428
+
1429
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1430
+
1431
+ loss = None
1432
+ if labels is not None:
1433
+ labels = labels.to(logits.device)
1434
+ if self.config.problem_type is None:
1435
+ if self.num_labels == 1:
1436
+ self.config.problem_type = "regression"
1437
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1438
+ self.config.problem_type = "single_label_classification"
1439
+ else:
1440
+ self.config.problem_type = "multi_label_classification"
1441
+
1442
+ if self.config.problem_type == "regression":
1443
+ loss_fct = MSELoss()
1444
+ if self.num_labels == 1:
1445
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1446
+ else:
1447
+ loss = loss_fct(pooled_logits, labels)
1448
+ elif self.config.problem_type == "single_label_classification":
1449
+ loss_fct = CrossEntropyLoss()
1450
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1451
+ elif self.config.problem_type == "multi_label_classification":
1452
+ loss_fct = BCEWithLogitsLoss()
1453
+ loss = loss_fct(pooled_logits, labels)
1454
+ if not return_dict:
1455
+ output = (pooled_logits,) + transformer_outputs[1:]
1456
+ return ((loss,) + output) if loss is not None else output
1457
+
1458
+ return SequenceClassifierOutputWithPast(
1459
+ loss=loss,
1460
+ logits=pooled_logits,
1461
+ past_key_values=transformer_outputs.past_key_values,
1462
+ hidden_states=transformer_outputs.hidden_states,
1463
+ attentions=transformer_outputs.attentions,
1464
+ )
1465
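# Worked example (toy values) of the last-non-padding-token pooling described in the class
# docstring above, assuming right padding with pad_token_id = 0. The argmax-over-equality plus
# modulo formulation matches the ONNX-friendly expression used in the forward.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 1]) -> index of the last real token in each row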
+
1466
+
1467
+ @add_start_docstrings(
1468
+ """
1469
+ The HyperCLOVAX Model transformer with a span classification head on top for extractive question-answering tasks like
1470
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1471
+ """,
1472
+ HyperCLOVAX_START_DOCSTRING,
1473
+ )
1474
+ class HyperCLOVAXForQuestionAnswering(HyperCLOVAXPreTrainedModel):
1475
+ base_model_prefix = "transformer"
1476
+
1477
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->HyperCLOVAX
1478
+ def __init__(self, config):
1479
+ super().__init__(config)
1480
+ self.transformer = HyperCLOVAXModel(config)
1481
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1482
+
1483
+ # Initialize weights and apply final processing
1484
+ self.post_init()
1485
+
1486
+ def get_input_embeddings(self):
1487
+ return self.transformer.embed_tokens
1488
+
1489
+ def set_input_embeddings(self, value):
1490
+ self.transformer.embed_tokens = value
1491
+
1492
+ @add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
1493
+ def forward(
1494
+ self,
1495
+ input_ids: Optional[torch.LongTensor] = None,
1496
+ attention_mask: Optional[torch.FloatTensor] = None,
1497
+ position_ids: Optional[torch.LongTensor] = None,
1498
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1499
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1500
+ start_positions: Optional[torch.LongTensor] = None,
1501
+ end_positions: Optional[torch.LongTensor] = None,
1502
+ output_attentions: Optional[bool] = None,
1503
+ output_hidden_states: Optional[bool] = None,
1504
+ return_dict: Optional[bool] = None,
1505
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1506
+ r"""
1507
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1508
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1509
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1510
+ are not taken into account for computing the loss.
1511
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1512
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1513
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1514
+ are not taken into account for computing the loss.
1515
+ """
1516
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1517
+
1518
+ outputs = self.transformer(
1519
+ input_ids,
1520
+ attention_mask=attention_mask,
1521
+ position_ids=position_ids,
1522
+ past_key_values=past_key_values,
1523
+ inputs_embeds=inputs_embeds,
1524
+ output_attentions=output_attentions,
1525
+ output_hidden_states=output_hidden_states,
1526
+ return_dict=return_dict,
1527
+ )
1528
+
1529
+ sequence_output = outputs[0]
1530
+
1531
+ logits = self.qa_outputs(sequence_output)
1532
+ start_logits, end_logits = logits.split(1, dim=-1)
1533
+ start_logits = start_logits.squeeze(-1).contiguous()
1534
+ end_logits = end_logits.squeeze(-1).contiguous()
1535
+
1536
+ total_loss = None
1537
+ if start_positions is not None and end_positions is not None:
1538
+ # If we are on multi-GPU, split add a dimension
1539
+ if len(start_positions.size()) > 1:
1540
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1541
+ if len(end_positions.size()) > 1:
1542
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1543
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1544
+ ignored_index = start_logits.size(1)
1545
+ start_positions = start_positions.clamp(0, ignored_index)
1546
+ end_positions = end_positions.clamp(0, ignored_index)
1547
+
1548
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1549
+ start_loss = loss_fct(start_logits, start_positions)
1550
+ end_loss = loss_fct(end_logits, end_positions)
1551
+ total_loss = (start_loss + end_loss) / 2
1552
+
1553
+ if not return_dict:
1554
+ output = (start_logits, end_logits) + outputs[2:]
1555
+ return ((total_loss,) + output) if total_loss is not None else output
1556
+
1557
+ return QuestionAnsweringModelOutput(
1558
+ loss=total_loss,
1559
+ start_logits=start_logits,
1560
+ end_logits=end_logits,
1561
+ hidden_states=outputs.hidden_states,
1562
+ attentions=outputs.attentions,
1563
+ )
1564
+
1565
+
1566
+ @add_start_docstrings(
1567
+ """
1568
+ The HyperCLOVAX Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1569
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1570
+ """,
1571
+ HyperCLOVAX_START_DOCSTRING,
1572
+ )
1573
+ class HyperCLOVAXForTokenClassification(HyperCLOVAXPreTrainedModel):
1574
+ def __init__(self, config):
1575
+ super().__init__(config)
1576
+ self.num_labels = config.num_labels
1577
+ self.model = HyperCLOVAXModel(config)
1578
+ if getattr(config, "classifier_dropout", None) is not None:
1579
+ classifier_dropout = config.classifier_dropout
1580
+ elif getattr(config, "hidden_dropout", None) is not None:
1581
+ classifier_dropout = config.hidden_dropout
1582
+ else:
1583
+ classifier_dropout = 0.1
1584
+ self.dropout = nn.Dropout(classifier_dropout)
1585
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1586
+
1587
+ # Initialize weights and apply final processing
1588
+ self.post_init()
1589
+
1590
+ def get_input_embeddings(self):
1591
+ return self.model.embed_tokens
1592
+
1593
+ def set_input_embeddings(self, value):
1594
+ self.model.embed_tokens = value
1595
+
1596
+ @add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
1597
+ def forward(
1598
+ self,
1599
+ input_ids: Optional[torch.LongTensor] = None,
1600
+ attention_mask: Optional[torch.Tensor] = None,
1601
+ position_ids: Optional[torch.LongTensor] = None,
1602
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1603
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1604
+ labels: Optional[torch.LongTensor] = None,
1605
+ use_cache: Optional[bool] = None,
1606
+ output_attentions: Optional[bool] = None,
1607
+ output_hidden_states: Optional[bool] = None,
1608
+ return_dict: Optional[bool] = None,
1609
+ ) -> Union[Tuple, TokenClassifierOutput]:
1610
+ r"""
1611
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1612
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1613
+ config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over the
1614
+ labelled tokens.
1615
+ """
1616
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1617
+
1618
+ outputs = self.model(
1619
+ input_ids,
1620
+ attention_mask=attention_mask,
1621
+ position_ids=position_ids,
1622
+ past_key_values=past_key_values,
1623
+ inputs_embeds=inputs_embeds,
1624
+ use_cache=use_cache,
1625
+ output_attentions=output_attentions,
1626
+ output_hidden_states=output_hidden_states,
1627
+ return_dict=return_dict,
1628
+ )
1629
+ sequence_output = outputs[0]
1630
+ sequence_output = self.dropout(sequence_output)
1631
+ logits = self.score(sequence_output)
1632
+
1633
+ loss = None
1634
+ if labels is not None:
1635
+ loss_fct = CrossEntropyLoss()
1636
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1637
+
1638
+ if not return_dict:
1639
+ output = (logits,) + outputs[2:]
1640
+ return ((loss,) + output) if loss is not None else output
1641
+
1642
+ return TokenClassifierOutput(
1643
+ loss=loss,
1644
+ logits=logits,
1645
+ hidden_states=outputs.hidden_states,
1646
+ attentions=outputs.attentions,
1647
+ )
1648
+
1649
+
1650
+ ################################################################################################
1651
+ ################################################################################################
1652
+ """
1653
+ liger kernel monkey patching
1654
+ https://github.com/linkedin/Liger-Kernel/blob/v0.5.2/src/liger_kernel/transformers/monkey_patch.py
1655
+ """
1656
+
1657
+ import inspect
1658
+ import logging
1659
+ from functools import partial
1660
+ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
1661
+
1662
+ import torch
1663
+ import torch.nn.functional as F
1664
+ import transformers
1665
+ from packaging import version
1666
+ from torch.nn import CrossEntropyLoss
1667
+ from transformers import PreTrainedModel
1668
+
1669
+ if TYPE_CHECKING:
1670
+ from transformers.cache_utils import Cache
1671
+
1672
+ import sys
1673
+
1674
+ from packaging.version import parse
1675
+
1676
+ if sys.version_info < (3, 8):
1677
+ import importlib_metadata
1678
+ else:
1679
+ import importlib.metadata as importlib_metadata
1680
+
1681
+ try:
1682
+ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
1683
+ from liger_kernel.transformers.functional import liger_cross_entropy
1684
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
1685
+ LigerFusedLinearCrossEntropyLoss,
1686
+ )
1687
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
1688
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb
1689
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
1690
+
1691
+ _is_liger_kernel_available = True
1692
+
1693
+ LIGER_KERNEL_MATCHING_VERSION = parse("0.5.2")
1694
+ liger_kernel_version = parse(importlib_metadata.version("liger_kernel"))
1695
+ _is_liger_kernel_version_matching = (
1696
+ liger_kernel_version.major,
1697
+ liger_kernel_version.minor,
1698
+ liger_kernel_version.release[-1],
1699
+ ) == (
1700
+ LIGER_KERNEL_MATCHING_VERSION.major,
1701
+ LIGER_KERNEL_MATCHING_VERSION.minor,
1702
+ LIGER_KERNEL_MATCHING_VERSION.release[-1],
1703
+ )
1704
+ except Exception:
1705
+ _is_liger_kernel_available = False
1706
+ _is_liger_kernel_version_matching = False
1707
+
1708
+
1709
+ def lce_forward_deprecated(
1710
+ self,
1711
+ input_ids: torch.LongTensor = None,
1712
+ attention_mask: Optional[torch.Tensor] = None,
1713
+ position_ids: Optional[torch.LongTensor] = None,
1714
+ past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
1715
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1716
+ labels: Optional[torch.LongTensor] = None,
1717
+ use_cache: Optional[bool] = None,
1718
+ output_attentions: Optional[bool] = None,
1719
+ output_hidden_states: Optional[bool] = None,
1720
+ return_dict: Optional[bool] = None,
1721
+ cache_position: Optional[torch.LongTensor] = None,
1722
+ num_logits_to_keep: int = 0,
1723
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1724
+
1725
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1726
+ output_hidden_states = (
1727
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1728
+ )
1729
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1730
+
1731
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1732
+ outputs = self.model(
1733
+ input_ids=input_ids,
1734
+ attention_mask=attention_mask,
1735
+ position_ids=position_ids,
1736
+ past_key_values=past_key_values,
1737
+ inputs_embeds=inputs_embeds,
1738
+ use_cache=use_cache,
1739
+ output_attentions=output_attentions,
1740
+ output_hidden_states=output_hidden_states,
1741
+ return_dict=return_dict,
1742
+ cache_position=cache_position,
1743
+ )
1744
+ hidden_states = outputs[0]
1745
+
1746
+ loss = None
1747
+ logits = None
1748
+
1749
+ if self.training and (labels is not None):
1750
+ if num_logits_to_keep != 0:
1751
+ hidden_states = hidden_states[:, -num_logits_to_keep:, :]  # NOTE: this slicing has not been verified for correctness
1752
+ hidden_states = hidden_states * self.config.logits_scaling ## muP
1753
+
1754
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
1755
+ shift_labels = labels[..., 1:].contiguous()
1756
+
1757
+ # flatten tokens
1758
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
1759
+ shift_labels = shift_labels.view(-1)
1760
+
1761
+ lce = LigerFusedLinearCrossEntropyLoss()
1762
+ loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
1763
+
1764
+ else:
1765
+ assert self.config.pretraining_tp == 1, "not supported"
1766
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1767
+ logits = logits * self.config.logits_scaling ## muP
1768
+
1769
+ if labels is not None:
1770
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1771
+ logits = logits.float()
1772
+ # Shift so that tokens < n predict n
1773
+ shift_logits = logits[..., :-1, :].contiguous()
1774
+ shift_labels = labels[..., 1:].contiguous()
1775
+ # Flatten the tokens
1776
+ loss_fct = CrossEntropyLoss()
1777
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1778
+ shift_labels = shift_labels.view(-1)
1779
+ # Enable model parallelism
1780
+ shift_labels = shift_labels.to(shift_logits.device)
1781
+ loss = loss_fct(shift_logits, shift_labels)
1782
+
1783
+ if not return_dict:
1784
+ output = (logits,) + outputs[1:]
1785
+ return (loss,) + output if loss is not None else output
1786
+
1787
+ return CausalLMOutputWithPast(
1788
+ loss=loss,
1789
+ logits=logits,
1790
+ past_key_values=outputs.past_key_values,
1791
+ hidden_states=outputs.hidden_states,
1792
+ attentions=outputs.attentions,
1793
+ )
1794
+
1795
+
1796
+ def _bind_method_to_module(module, method_name: str, new_method: Callable):
1797
+ # Binds a new method to a module instance so that self is passed as the first argument
1798
+ module.__dict__[method_name] = new_method.__get__(module, module.__class__)
1799
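# Small illustration (not part of the file) of the descriptor trick `_bind_method_to_module`
# relies on: `function.__get__(instance, cls)` produces a bound method, so assigning it into the
# instance `__dict__` overrides the method for that single instance without touching the class.
class Greeter:
    def forward(self):
        return "original"

def patched_forward(self):
    return "patched"

g = Greeter()
g.__dict__["forward"] = patched_forward.__get__(g, Greeter)
print(g.forward())          # "patched"  -- only this instance is affected
print(Greeter().forward())  # "original"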
+
1800
+
1801
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
1802
+ module.offset = offset
1803
+ module.casting_mode = casting_mode
1804
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
1805
+ module.in_place = in_place
1806
+ _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
1807
+ _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
1808
+
1809
+
1810
+ def apply_liger_kernel_to_hyperclovax(
1811
+ rope: bool = True,
1812
+ cross_entropy: bool = False,
1813
+ fused_linear_cross_entropy: bool = True,
1814
+ rms_norm: bool = True,
1815
+ swiglu: bool = True,
1816
+ model: PreTrainedModel = None,
1817
+ ) -> None:
1818
+
1819
+ assert not cross_entropy, "not supported"
1820
+ if rope:
1821
+ apply_rotary_pos_emb = liger_rotary_pos_emb
1822
+ if rms_norm:
1823
+ HyperCLOVAXRMSNorm = LigerRMSNorm
1824
+ if swiglu:
1825
+ HyperCLOVAXMLP = LigerSwiGLUMLP
1826
+ # to use VLM forward in VLM repo
1827
+ # if fused_linear_cross_entropy:
1828
+ # HyperCLOVAXForCausalLM.forward = lce_forward_deprecated
1829
+
1830
+ if model is not None:
1831
+ # The model instance already exists, so we need to additionally patch the
1832
+ # instance variables that reference already-instantiated modules (e.g. HyperCLOVAXRMSNorm or HyperCLOVAXMLP)
1833
+
1834
+ # get the base model from the model instance
1835
+ base_model: HyperCLOVAXModel = getattr(model, model.base_model_prefix, model)
1836
+
1837
+ if rms_norm:
1838
+ _patch_rms_norm_module(base_model.norm)
1839
+
1840
+ for decoder_layer in base_model.layers:
1841
+ if swiglu:
1842
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
1843
+ if rms_norm:
1844
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1845
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1846
+ if decoder_layer.use_post_norm:
1847
+ _patch_rms_norm_module(decoder_layer.post_norm1)
1848
+ _patch_rms_norm_module(decoder_layer.post_norm2)
1849
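# Hedged usage sketch for the instance-level patching above (the repo id is a placeholder and
# liger_kernel must be installed): the helper rebinds forward methods on RMSNorm/SwiGLU modules
# that already exist on the model instance; note that `cross_entropy=True` is rejected by the
# assert at the top of `apply_liger_kernel_to_hyperclovax`.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/or/repo-id", trust_remote_code=True)  # placeholder
apply_liger_kernel_to_hyperclovax(rope=True, rms_norm=True, swiglu=True, model=model)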
+
1850
+
1851
+ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
1852
+ model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
1853
+ assert model_type == "hyperclovax"
1854
+ apply_fn = apply_liger_kernel_to_hyperclovax
1855
+ apply_fn_signature = inspect.signature(apply_fn)
1856
+
1857
+ # Filter out the keyword arguments that are not supported by the apply function
1858
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
1859
+ logger.info(
1860
+ f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
1861
+ )
1862
+ apply_fn(model=model, **applicable_kwargs)
1863
+
1864
+
1865
+ ################################################################################################
1866
+ ################################################################################################
modeling_vlm.py ADDED
The diff for this file is too large to render. See raw diff
 
patch_vuvlm.py ADDED
@@ -0,0 +1,1085 @@
1
+ import contextlib
2
+ import gc
3
+ import inspect
4
+ import json
5
+ import os
6
+ import time
7
+ from functools import partial
8
+ from pathlib import Path
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn as nn
14
+ from liger_kernel.transformers import (
15
+ LigerCrossEntropyLoss,
16
+ LigerFusedLinearCrossEntropyLoss,
17
+ )
18
+ from torch.nn import CrossEntropyLoss
19
+ from transformers import AutoTokenizer
20
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+ from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
23
+
24
+ from hcxvlm.models.ulysses.sp_utils import (
25
+ gather_outputs_and_unpad,
26
+ get_ulysses_sequence_parallel_group,
27
+ get_ulysses_sequence_parallel_rank,
28
+ get_ulysses_sequence_parallel_world_size,
29
+ slice_input_tensor,
30
+ )
31
+
32
+ from .configuration_vlm import HCXVisionConfig
33
+ from .modeling_vlm import HCXVisionForCausalLM, get_rank
34
+
35
+ extra_special_tokens = {
36
+ "image_token": "<|IMAGE_PAD|>",
37
+ "discrete_image_token": "<|DISCRETE_IMAGE_PAD|>",
38
+ "discrete_image_unit_0_id": "<|vision00000|>",
39
+ "video_token": "<|VIDEO_PAD|>",
40
+ "video_audio_token": "<|VIDEO_AUDIO_PAD|>",
41
+ "audio_token": "<|AUDIO_PAD|>",
42
+ "discrete_audio_token": "<|DISCRETE_AUDIO_PAD|>",
43
+ "discrete_audio_unit_0_id": "<|audio0000|>",
44
+ }
45
+
46
+
47
+ def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""):
48
+ old_keys = []
49
+ new_keys = []
50
+ for key in state_dict.keys():
51
+ new_key = None
52
+ if "gamma" in key:
53
+ new_key = key.replace("gamma", "weight")
54
+ if "beta" in key:
55
+ new_key = key.replace("beta", "bias")
56
+ if new_key:
57
+ old_keys.append(key)
58
+ new_keys.append(new_key)
59
+ for old_key, new_key in zip(old_keys, new_keys):
60
+ state_dict[new_key] = state_dict.pop(old_key)
61
+
62
+ metadata = getattr(state_dict, "_metadata", None)
63
+ state_dict = state_dict.copy()
64
+ if metadata is not None:
65
+ state_dict._metadata = metadata
66
+
67
+ error_msgs = []
68
+
69
+ def load(module: nn.Module, state_dict, prefix=""):
70
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
71
+ args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs)
72
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
73
+ if is_deepspeed_zero3_enabled():
74
+ import deepspeed
75
+
76
+ named_parameters = dict(
77
+ module.named_parameters(prefix=prefix[:-1], recurse=False)
78
+ )
79
+ params_to_gather = [
80
+ named_parameters[k]
81
+ for k in state_dict.keys()
82
+ if k in named_parameters
83
+ ]
84
+ if len(params_to_gather) > 0:
85
+ with deepspeed.zero.GatheredParameters(
86
+ params_to_gather, modifier_rank=0
87
+ ):
88
+ if torch.distributed.get_rank() == 0:
89
+ module._load_from_state_dict(*args)
90
+ else:
91
+ module._load_from_state_dict(*args)
92
+
93
+ for name, child in module._modules.items():
94
+ if child is not None:
95
+ load(child, state_dict, prefix + name + ".")
96
+
97
+ load(model_to_load, state_dict, prefix=start_prefix)
98
+ del state_dict
99
+
100
+ return error_msgs
101
+
102
+
103
+ def load_sharded_checkpoint(
104
+ model,
105
+ folder,
106
+ pick_prefix="",
107
+ replace_prefix_list=[],
108
+ replace_prefix_dict={},
109
+ print_info=True,
110
+ ):
111
+ if folder is None:
112
+ return {}
113
+
114
+ files = os.listdir(folder)
115
+
116
+ pytorch_bin_files = [
117
+ file
118
+ for file in files
119
+ if file.startswith("pytorch_model") and file.endswith(".bin")
120
+ ]
121
+ safetensor_files = [file for file in files if file.endswith(".safetensors")]
122
+ shard_index_file = [file for file in files if file.endswith(".index.json")]
123
+
124
+ index_present = len(shard_index_file) > 0
125
+ index_file = os.path.join(folder, shard_index_file[0]) if index_present else []
126
+
127
+ is_safetensor = len(safetensor_files) > 0
128
+
129
+ model_keys = model.state_dict().keys()
130
+
131
+ if is_safetensor:
132
+ from safetensors.torch import load_file
133
+
134
+ load_function = load_file
135
+ shard_files = safetensor_files
136
+ else:
137
+ load_function = partial(torch.load, map_location="cpu")
138
+ shard_files = pytorch_bin_files
139
+
140
+ if index_present:
141
+ with open(index_file, "r", encoding="utf-8") as f:
142
+ index = json.load(f)
143
+ loaded_keys = index["weight_map"].keys()
144
+ if pick_prefix:
145
+ loaded_keys = [
146
+ k[len(pick_prefix) :] for k in loaded_keys if k.startswith(pick_prefix)
147
+ ]
148
+ if replace_prefix_list:
149
+ for rep_prefix in replace_prefix_list:
150
+ loaded_keys = [
151
+ k[len(rep_prefix) :] if k.startswith(rep_prefix) else k
152
+ for k in loaded_keys
153
+ ]
154
+ if replace_prefix_dict:
155
+ for rep_prefix in replace_prefix_dict:
156
+ loaded_keys = [
157
+ (
158
+ k.replace(rep_prefix, replace_prefix_dict[rep_prefix])
159
+ if k.startswith(rep_prefix)
160
+ else k
161
+ )
162
+ for k in loaded_keys
163
+ ]
164
+
165
+ for i, shard_file in enumerate(shard_files):
166
+ state_dict = load_function(os.path.join(folder, shard_file))
167
+
168
+ if pick_prefix:
169
+ state_dict = {
170
+ k[len(pick_prefix) :]: v
171
+ for k, v in state_dict.items()
172
+ if k.startswith(pick_prefix)
173
+ }
174
+
175
+ for rep_prefix in replace_prefix_list:
176
+ state_dict = {
177
+ k[len(rep_prefix) :] if k.startswith(rep_prefix) else k: v
178
+ for k, v in state_dict.items()
179
+ }
180
+
181
+ for rep_prefix in replace_prefix_dict:
182
+ state_dict = {
183
+ (
184
+ k.replace(rep_prefix, replace_prefix_dict[rep_prefix])
185
+ if k.startswith(rep_prefix)
186
+ else k
187
+ ): v
188
+ for k, v in state_dict.items()
189
+ }
190
+
191
+ if is_deepspeed_zero3_enabled():
192
+ rank = torch.distributed.get_rank()
193
+ print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}")
194
+ load_state_dict_into_model(model, state_dict, strict=False)
195
+ elif is_fsdp_enabled():
196
+ if is_local_dist_rank_0():
197
+ model.load_state_dict(state_dict, strict=False)
198
+ else:
199
+ model.load_state_dict(state_dict, strict=False)
200
+
201
+ if not index_present:
202
+ loaded_keys = state_dict.keys()
203
+
204
+ del state_dict
205
+ gc.collect()
206
+
207
+ missing_keys = [key for key in model_keys if key not in loaded_keys]
208
+ unexpected_keys = [key for key in loaded_keys if key not in model_keys]
209
+
210
+ if get_rank() == 0 and print_info:
211
+ print(f"[info] missing_keys: {missing_keys}")
212
+ print(f"[info] unexpected_keys: {unexpected_keys}")
213
+
214
+ return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
215
+
216
+
217
+ class HCXVisionForCausalLM_VU(HCXVisionForCausalLM):
218
+ def __init__(self, config, **kwargs):
219
+ self.use_liger = kwargs.pop("use_liger", True)
220
+ self.use_fused_ce = kwargs.pop("use_fused_ce", True)
221
+ self.use_meansum_loss = kwargs.pop("use_meansum_loss", True)
222
+ self.use_turnmeansum_loss = kwargs.pop("use_turnmeansum_loss", False)
223
+ self.use_sqrtsum_loss = kwargs.pop("use_sqrtsum_loss", False)
224
+ use_sum_loss = bool(kwargs.pop("use_sum_loss", False))
225
+
226
+ self.sequence_parallel_size = kwargs.pop("sequence_parallel_size", 1)
227
+ self.sp_manager = kwargs.pop("sp_manager", None)
228
+ self.train_video = kwargs.pop("train_video", False)
229
+
230
+ assert (
231
+ int(self.use_meansum_loss)
232
+ + int(self.use_turnmeansum_loss)
233
+ + int(self.use_sqrtsum_loss)
234
+ ) <= 1, "use_meansum_loss, use_turnmeansum_loss, use_sqrtsum_loss 중 둘 이상을 동시에 True로 설정할 수 없습니다."
235
+
236
+ if self.use_meansum_loss or self.use_turnmeansum_loss or self.use_sqrtsum_loss:
237
+ self.reduction = "none"
238
+ elif use_sum_loss:
239
+ self.reduction = "sum"
240
+ else:
241
+ self.reduction = "mean"
242
+
243
+ super().__init__(config, **kwargs)
244
+ if config.text_config.model_type == "hyperclovax" and self.use_liger:
245
+ self.language_model._get_apply_liger_kernel_converter()(
246
+ model=self.language_model
247
+ )
248
+ print("[info] use liger kernel for hcx 24b")
249
+ if config.freeze_encoder:
250
+ for param in self.vision_model.parameters():
251
+ param.requires_grad = False
252
+ assert not any(
253
+ param.requires_grad
254
+ for param in self.vision_model.parameters()
255
+ )
256
+
257
+ @classmethod
258
+ def from_pretrained(
259
+ cls,
260
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
261
+ text_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
262
+ vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
263
+ discrete_vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
264
+ audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
265
+ discrete_audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
266
+ q_former_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
267
+ without_llm: bool = False,
268
+ *model_args,
269
+ **kwargs,
270
+ ):
271
+ """
272
+ :param pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for the LLM (text_model_name_or_path), e.g. /path/to/model/
273
+ :param vision_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VisionModule(HyperClova-VisionModule) e.g. /path/to/vision/module/
274
+ :param q_former_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VLM e.g. /path/to/vlm/checkpoint/
275
+ :param without_llm: bool : False - init/load the LLM weights from the pre-trained checkpoint; True - init the LLM from a dummy (zero-layer) config
276
+ :param model_args:
277
+ :param kwargs:
278
+ :return:
279
+ """
280
+ assert pretrained_model_name_or_path is not None or (
281
+ text_model_name_or_path is not None
282
+ and vision_model_name_or_path is not None
283
+ )
284
+
285
+ cache_dirpath = kwargs.pop("cache_dirpath", None)
286
+ if cache_dirpath is None:
287
+ cache_dirpath = "~/.cache"
288
+
289
+ runtime_only_keys = {
290
+ "use_liger",
291
+ "use_fused_ce",
292
+ "use_meansum_loss",
293
+ "use_turnmeansum_loss",
294
+ "use_sqrtsum_loss",
295
+ "use_sum_loss",
296
+ "sequence_parallel_size",
297
+ "sp_manager",
298
+ "train_video",
299
+ }
300
+ runtime_kwargs = {}
301
+ for k in list(runtime_only_keys):
302
+ if k in kwargs:
303
+ runtime_kwargs[k] = kwargs.pop(k)
304
+
305
+ kwargs["vision_model_name_or_path"] = vision_model_name_or_path
306
+ kwargs["discrete_vision_model_name_or_path"] = (
307
+ discrete_vision_model_name_or_path
308
+ )
309
+ kwargs["audio_model_name_or_path"] = audio_model_name_or_path
310
+ kwargs["discrete_audio_model_name_or_path"] = discrete_audio_model_name_or_path
311
+
312
+ save_only_vision = (
313
+ kwargs.pop("save_only_vision") if "save_only_vision" in kwargs else False
314
+ )
315
+ save_only_qformer = (
316
+ kwargs.pop("save_only_qformer") if "save_only_qformer" in kwargs else False
317
+ )
318
+ save_shard_size = (
319
+ kwargs.pop("save_shard_size") if "save_shard_size" in kwargs else "5GB"
320
+ )
321
+
322
+ def _purge_runtime_from_config(cfg):
323
+ for rk in runtime_only_keys:
324
+ if hasattr(cfg, rk):
325
+ delattr(cfg, rk)
326
+
327
+ template_path = "hcxvlm/dataset/chat_template.jinja"
328
+ with open(template_path, "r", encoding="utf-8") as f:
329
+ chat_template_str = f.read()
330
+ if without_llm:
331
+ assert pretrained_model_name_or_path is not None and os.path.exists(
332
+ pretrained_model_name_or_path
333
+ )
334
+
335
+ dummy_config = HCXVisionConfig.from_pretrained(
336
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
337
+ *model_args,
338
+ **kwargs,
339
+ )
340
+ _purge_runtime_from_config(dummy_config)
341
+ dummy_config.text_config.num_hidden_layers = 0
342
+ dummy_config.text_config.num_attention_heads = 1
343
+
344
+ if isinstance(
345
+ dummy_config.vision_model_name_or_path, str
346
+ ) and os.path.exists(dummy_config.vision_model_name_or_path):
347
+ vision_model_name_or_path = dummy_config.vision_model_name_or_path
348
+ assert isinstance(vision_model_name_or_path, str) and os.path.exists(
349
+ vision_model_name_or_path
350
+ ), f"# [error] invalid vision_model_name_or_path: {vision_model_name_or_path}"
351
+ dummy_config.vision_model_name_or_path = vision_model_name_or_path
352
+ dummy_config.vision_config._name_or_path = vision_model_name_or_path
353
+ dummy_config.vision_config.vison_pretrained_name_or_path = (
354
+ vision_model_name_or_path
355
+ )
356
+
357
+ model = super().from_pretrained(
358
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
359
+ without_llm=True,
360
+ config=dummy_config,
361
+ *model_args,
362
+ **{**kwargs, **runtime_kwargs},
363
+ )
364
+ model.tokenizer = AutoTokenizer.from_pretrained(
365
+ pretrained_model_name_or_path
366
+ )
367
+ model.tokenizer.chat_template = chat_template_str
368
+ model.transformer = None
369
+ else:
370
+ if pretrained_model_name_or_path is not None and (
371
+ audio_model_name_or_path is not None
372
+ or discrete_audio_model_name_or_path is not None
373
+ or discrete_vision_model_name_or_path is not None
374
+ ):
375
+ assert (
376
+ audio_model_name_or_path is not None
377
+ and discrete_audio_model_name_or_path is not None
378
+ and discrete_vision_model_name_or_path is not None
379
+ )
380
+ print(f"[DEBUG] image stage2 끝난 시점에서 audio 를 stage3 로 붙일때.")
381
+ pt_config = HCXVisionConfig.from_pretrained(
382
+ pretrained_model_name_or_path
383
+ )
384
+ _purge_runtime_from_config(pt_config)
385
+ config_dict = pt_config.to_dict()
386
+ config_dict["audio_model_name_or_path"] = audio_model_name_or_path
387
+ config_dict["discrete_audio_model_name_or_path"] = (
388
+ discrete_audio_model_name_or_path
389
+ )
390
+ config_dict["discrete_vision_model_name_or_path"] = (
391
+ discrete_vision_model_name_or_path
392
+ )
393
+ config = HCXVisionConfig.from_dict(config_dict)
394
+ print(f"config: {config}")
395
+ model = super().from_pretrained(
396
+ pretrained_model_name_or_path,
397
+ without_llm=False,
398
+ config=config,
399
+ _fast_init=False,
400
+ *model_args,
401
+ **kwargs,
402
+ )
403
+ model.tokenizer = AutoTokenizer.from_pretrained(
404
+ pretrained_model_name_or_path
405
+ )
406
+ model.tokenizer.chat_template = chat_template_str
407
+ elif isinstance(q_former_model_name_or_path, str):
408
+ config = HCXVisionConfig.from_dict(
409
+ {"text_model_name_or_path": text_model_name_or_path, **kwargs}
410
+ )
411
+ _purge_runtime_from_config(config)
412
+ model = super().from_pretrained(
413
+ q_former_model_name_or_path,
414
+ without_llm=False,
415
+ config=config,
416
+ _fast_init=False,
417
+ *model_args,
418
+ **{**kwargs, **runtime_kwargs},
419
+ )
420
+ model.tokenizer = AutoTokenizer.from_pretrained(
421
+ q_former_model_name_or_path
422
+ )
423
+ model.tokenizer.chat_template = chat_template_str
424
+ elif pretrained_model_name_or_path is not None:
425
+ config = HCXVisionConfig.from_pretrained(
426
+ pretrained_model_name_or_path, *model_args, **kwargs
427
+ )
428
+ _purge_runtime_from_config(config)
429
+ model = super().from_pretrained(
430
+ pretrained_model_name_or_path,
431
+ *model_args,
432
+ config=config,
433
+ **runtime_kwargs,
434
+ )
435
+ model.tokenizer = AutoTokenizer.from_pretrained(
436
+ pretrained_model_name_or_path
437
+ )
438
+ model.tokenizer.chat_template = chat_template_str
439
+ else:
440
+ config = HCXVisionConfig.from_dict(
441
+ {"text_model_name_or_path": text_model_name_or_path, **kwargs}
442
+ )
443
+ _purge_runtime_from_config(config)
444
+ model = HCXVisionForCausalLM_VU(
445
+ config, *model_args, **{**kwargs, **runtime_kwargs}
446
+ )
447
+ model.tokenizer = AutoTokenizer.from_pretrained(text_model_name_or_path)
448
+ model.tokenizer.chat_template = chat_template_str
449
+ model.mm_projector.apply(model._init_weights)
450
+
451
+ img_start_id = model.tokenizer.encode(
452
+ extra_special_tokens["image_token"], add_special_tokens=False
453
+ )
454
+ assert (
455
+ len(img_start_id) == 1
456
+ ), f'{extra_special_tokens["image_token"]} was not encoded into a single special token. Encoding result: {img_start_id}'
457
+ model.config.img_start_id = img_start_id[0]
458
+ model.config.image_token_id = img_start_id[0]
459
+
460
+ video_start_id = model.tokenizer.encode(
461
+ extra_special_tokens["video_token"], add_special_tokens=False
462
+ )
463
+ assert (
464
+ len(video_start_id) == 1
465
+ ), f"video_token was not encoded into a single special token. Encoding result: {video_start_id}"
466
+ model.config.video_start_id = video_start_id[0]
467
+ model.config.video_token_id = video_start_id[0]
468
+
469
+ video_audio_start_id = model.tokenizer.encode(
470
+ extra_special_tokens["video_audio_token"], add_special_tokens=False
471
+ )
472
+ assert (
473
+ len(video_audio_start_id) == 1
474
+ ), f"video_audio_token was not encoded into a single special token. Encoding result: {video_audio_start_id}"
475
+ model.config.video_audio_start_id = video_audio_start_id[0]
476
+ model.config.video_audio_token_id = video_audio_start_id[0]
477
+
478
+ if (
479
+ audio_model_name_or_path is not None
480
+ or discrete_audio_model_name_or_path is not None
481
+ or discrete_vision_model_name_or_path is not None
482
+ ):
483
+ audio_start_id = model.tokenizer.encode(
484
+ extra_special_tokens["audio_token"], add_special_tokens=False
485
+ )
486
+ assert (
487
+ len(audio_start_id) == 1
488
+ ), f"audio_token was not encoded into a single special token. Encoding result: {audio_start_id}"
489
+ model.config.audio_start_id = audio_start_id[0]
490
+ model.config.audio_token_id = audio_start_id[0]
491
+
492
+ discrete_audio_start_id = model.tokenizer.encode(
493
+ extra_special_tokens["discrete_audio_token"], add_special_tokens=False
494
+ )
495
+ assert (
496
+ len(discrete_audio_start_id) == 1
497
+ ), f"discrete_audio_token was not encoded into a single special token. Encoding result: {discrete_audio_start_id}"
498
+ model.config.discrete_audio_start_id = discrete_audio_start_id[0]
499
+ model.config.discrete_audio_token_id = discrete_audio_start_id[0]
500
+ discrete_audio_unit_0_id = model.tokenizer.encode(
501
+ extra_special_tokens["discrete_audio_unit_0_id"],
502
+ add_special_tokens=False,
503
+ )
504
+ assert (
505
+ len(discrete_audio_unit_0_id) == 1
506
+ ), f'{extra_special_tokens["discrete_audio_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_audio_unit_0_id}'
507
+ model.config.discrete_audio_unit_0_id = discrete_audio_unit_0_id[0]
508
+
509
+ discrete_image_start_id = model.tokenizer.encode(
510
+ extra_special_tokens["discrete_image_token"], add_special_tokens=False
511
+ )
512
+ assert (
513
+ len(discrete_image_start_id) == 1
514
+ ), f'{extra_special_tokens["discrete_image_token"]} was not encoded into a single special token. Encoding result: {discrete_image_start_id}'
515
+ model.config.discrete_image_start_id = discrete_image_start_id[0]
516
+ model.config.discrete_image_token_id = discrete_image_start_id[0]
517
+ discrete_image_unit_0_id = model.tokenizer.encode(
518
+ extra_special_tokens["discrete_image_unit_0_id"],
519
+ add_special_tokens=False,
520
+ )
521
+ assert (
522
+ len(discrete_image_unit_0_id) == 1
523
+ ), f'{extra_special_tokens["discrete_image_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_image_unit_0_id}'
524
+ model.config.discrete_image_unit_0_id = discrete_image_unit_0_id[0]
525
+
526
+ model.save_only_vision = save_only_vision
527
+ model.save_only_qformer = save_only_qformer
528
+ model.save_shard_size = save_shard_size
529
+
530
+ if pretrained_model_name_or_path is None or (
531
+ pretrained_model_name_or_path is not None
532
+ and audio_model_name_or_path is not None
533
+ ):
534
+ vision_model_name_or_path = kwargs.get("vision_model_name_or_path", None)
535
+ if vision_model_name_or_path is not None:
536
+ load_sharded_checkpoint(model.vision_model, vision_model_name_or_path)
537
+ if get_rank() == 0:
538
+ print("[info] vision model loading complete")
539
+
540
+ discrete_vision_model_name_or_path = kwargs.get(
541
+ "discrete_vision_model_name_or_path", None
542
+ )
543
+ if discrete_vision_model_name_or_path is not None:
544
+
545
+ model.discrete_vision_model.load_state_dict(
546
+ torch.load(
547
+ discrete_vision_model_name_or_path,
548
+ map_location=model.device,
549
+ weights_only=False,
550
+ )["model"]["sd"],
551
+ strict=True,
552
+ )
553
+ if get_rank() == 0:
554
+ print("[info] discrete vision model loading complete")
555
+
556
+ audio_model_name_or_path = kwargs.get("audio_model_name_or_path", None)
557
+ if audio_model_name_or_path is not None:
558
+ load_sharded_checkpoint(model.audio_model, audio_model_name_or_path)
559
+ if get_rank() == 0:
560
+ print("[info] audio model loading complete")
561
+
562
+ discrete_audio_model_name_or_path = kwargs.get(
563
+ "discrete_audio_model_name_or_path", None
564
+ )
565
+ if discrete_audio_model_name_or_path is not None:
566
+
567
+ model.discrete_audio_model.load_state_dict(
568
+ torch.load(
569
+ discrete_audio_model_name_or_path,
570
+ map_location=model.device,
571
+ weights_only=False,
572
+ ),
573
+ strict=True,
574
+ )
575
+ if get_rank() == 0:
576
+ print("[info] discrete audio model loading complete")
577
+
578
+ if text_model_name_or_path is not None:
579
+ load_sharded_checkpoint(model.language_model, text_model_name_or_path)
580
+ if get_rank() == 0:
581
+ print("[info] text model loading complete")
582
+
583
+ if isinstance(q_former_model_name_or_path, str):
584
+ assert Path(
585
+ q_former_model_name_or_path
586
+ ).exists(), f"# [error] given q_former_name_or_path not exist: {q_former_model_name_or_path}"
587
+
588
+ load_result = load_sharded_checkpoint(
589
+ model,
590
+ q_former_model_name_or_path,
591
+ replace_prefix_dict={
592
+ "vision_model.image_encoder.model.vision_tower": "vision_model",
593
+ "model": "language_model.model",
594
+ "lm_head.weight": "language_model.lm_head.weight",
595
+ },
596
+ print_info=False,
597
+ )
598
+
599
+ if get_rank() == 0:
600
+ missing_keys_summary = dict()
601
+ for key in load_result["missing_keys"]:
602
+ if key.split(".")[0] in missing_keys_summary:
603
+ missing_keys_summary[key.split(".")[0]] += 1
604
+ else:
605
+ missing_keys_summary[key.split(".")[0]] = 1
606
+ print(f"[info] missing_keys summary : {missing_keys_summary}")
607
+ print("[info] q_former model loading complete")
608
+
609
+ config: HCXVisionConfig = model.config
610
+ if config.model_type != "vlm":
611
+ model.config.model_type = "vlm"
612
+
613
+ return model
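+
+ # Editor's note: illustrative call patterns only; every path below is a
+ # hypothetical placeholder.
+ # >>> # stage-3 style init: attach audio / discrete encoders on top of an
+ # >>> # image-stage checkpoint
+ # >>> model = HCXVisionForCausalLM_VU.from_pretrained(
+ # ...     pretrained_model_name_or_path="/path/to/stage2/checkpoint",
+ # ...     audio_model_name_or_path="/path/to/audio/encoder",
+ # ...     discrete_audio_model_name_or_path="/path/to/discrete_audio.pt",
+ # ...     discrete_vision_model_name_or_path="/path/to/discrete_vision.pt",
+ # ... )
+ # >>> # plain load from a full pre-trained checkpoint
+ # >>> model = HCXVisionForCausalLM_VU.from_pretrained("/path/to/full/checkpoint")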
614
+
615
+ def _pad_sequence_for_sp(
616
+ self,
617
+ inputs_embeds: torch.Tensor,
618
+ labels: Optional[torch.Tensor],
619
+ sp_world_size: int,
620
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
621
+ """
622
+ Ensure sequence length is divisible by the SP group size by padding on the sequence dimension.
623
+ Returns the possibly padded (inputs_embeds, labels).
624
+ """
625
+ batch_size, seqlen, hidden_size = inputs_embeds.shape
626
+ remainder = seqlen % sp_world_size
627
+ if remainder != 0:
628
+ print(
629
+ f"[info] Padding sequence dimension to make it divisible by {sp_world_size}"
630
+ )
631
+ pad_len = sp_world_size - remainder
632
+ pad_embeds = torch.zeros(
633
+ (batch_size, pad_len, hidden_size),
634
+ dtype=inputs_embeds.dtype,
635
+ device=inputs_embeds.device,
636
+ )
637
+ inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1)
638
+
639
+ if labels is not None:
640
+ ignore_index = getattr(self.config, "ignore_index", -100)
641
+ pad_labels = torch.full(
642
+ (batch_size, pad_len),
643
+ fill_value=ignore_index,
644
+ dtype=labels.dtype,
645
+ device=labels.device,
646
+ )
647
+ labels = torch.cat([labels, pad_labels], dim=1)
648
+
649
+ return inputs_embeds, labels
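+
+ # Editor's note (illustrative arithmetic): with seqlen = 1026 and a sequence-parallel
+ # group of 4, remainder = 1026 % 4 = 2, so pad_len = 4 - 2 = 2 zero embeddings (and
+ # two ignore_index labels) are appended, giving 1028 = 4 * 257 positions that split
+ # evenly across the SP ranks.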
650
+
651
+ def forward(
652
+ self,
653
+ input_ids: Optional[torch.LongTensor] = None,
654
+ pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
655
+ discrete_pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
656
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
657
+ attention_mask: Optional[torch.FloatTensor] = None,
658
+ position_ids: Optional[torch.LongTensor] = None,
659
+ inputs_embeds: Optional[torch.FloatTensor] = None,
660
+ labels: Optional[torch.LongTensor] = None,
661
+ use_cache: Optional[bool] = None,
662
+ output_attentions: Optional[bool] = None,
663
+ output_hidden_states: Optional[bool] = None,
664
+ return_dict: Optional[bool] = None,
665
+ image_sizes: Optional[List[List[List[int]]]] = None,
666
+ mm_query_lengths: Optional[List[List[int]]] = None,
667
+ non_mm_query_lengths: Optional[List[List[int]]] = None,
668
+ img_start_ids_list: Optional[List[List[int]]] = None,
669
+ num_queries_vis_abstractors: Optional[List[List[int]]] = None,
670
+ num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
671
+ first_last_frames_slows: Optional[List[List[bool]]] = None,
672
+ is_videos: Optional[List[List[bool]]] = None,
673
+ image_grid_thw: Optional[torch.LongTensor] = None,
674
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
675
+ video_grid_thw: Optional[torch.LongTensor] = None,
676
+ video_audio_values: Optional[torch.FloatTensor] = None,
677
+ video_audio_masks: Optional[torch.FloatTensor] = None,
678
+ audio_values: Optional[torch.FloatTensor] = None,
679
+ discrete_audio_values: Optional[torch.FloatTensor] = None,
680
+ discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None,
681
+ audio_masks: Optional[torch.LongTensor] = None,
682
+ **kwargs,
683
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
684
+ """
685
+ :param input_ids: torch.int64 : torch.Size([batchsize, variable]) : SystemPrompt and question text token indices from the tokenizer.
686
+ In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data.
687
+ :param pixel_values: List of List of 4D tensor (torch.float32)
688
+ Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
689
+ :param past_key_values: None
690
+ :param inputs_embeds: None
691
+ :param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1) + num visual tokens)]; all visual token positions are set to IGNORE_INDEX
692
+ :param use_cache: None
693
+ :param output_attentions: Optional[bool] : get attention weights of each layer of the transformer network (True: included in the output, False: not included)
694
+ :param output_hidden_states: Optional[bool] : get hidden states of each layer of the transformer network (True: included in the output, False: not included)
695
+ :param return_dict: Optional[bool] : True - return a dict, False - return a tuple
696
+ :param image_sizes: Stacked as a List of List, representing image sizes (width, height).
697
+ In cases where a sample contains no images, a single dummy image is included.
698
+ :param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
699
+ In cases where a sample does not contain any images, an empty list is included.
700
+ :param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
701
+ :param img_start_ids_list: contains the indices of the img_start_id tokens for each sample.
702
+ :param num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
703
+ :param num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
704
+ :param first_last_frames_slows: A List of List of booleans indicating whether slow mode is applied only to the first and last frames for each sample in a batch.
705
+ :param is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video.
706
+ :param image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
707
+ :param pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder.
708
+ :param video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
709
+ :return:
710
+ """
711
+
712
+ if self.sp_manager is not None and self.train_video:
713
+ sp_group = get_ulysses_sequence_parallel_group()
714
+ if sp_group is not None:
715
+ sp_rank = get_ulysses_sequence_parallel_rank(sp_group)
716
+ sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group)
717
+ if sp_rank == 0:
718
+ payload = {
719
+ "input_ids": input_ids,
720
+ "labels": labels,
721
+ "pixel_values": pixel_values,
722
+ "image_grid_thw": image_grid_thw,
723
+ "pixel_values_videos": pixel_values_videos,
724
+ "video_grid_thw": video_grid_thw,
725
+ "video_audio_values": video_audio_values,
726
+ "video_audio_masks": video_audio_masks,
727
+ }
728
+ else:
729
+ payload = {
730
+ "input_ids": None,
731
+ "labels": None,
732
+ "pixel_values": None,
733
+ "image_grid_thw": None,
734
+ "pixel_values_videos": None,
735
+ "video_grid_thw": None,
736
+ "video_audio_values": None,
737
+ "video_audio_masks": None,
738
+ }
739
+
740
+ obj_list = [payload]
741
+ src_global_rank = dist.get_global_rank(sp_group, 0)
742
+ dist.broadcast_object_list(
743
+ obj_list, src=src_global_rank, group=sp_group
744
+ )
745
+ payload = obj_list[0]
746
+
747
+ if sp_rank != 0:
748
+ device = input_ids.device
749
+
750
+ input_ids = payload["input_ids"]
751
+ if isinstance(input_ids, torch.Tensor):
752
+ input_ids = input_ids.to(device)
753
+
754
+ labels = payload["labels"]
755
+ if isinstance(labels, torch.Tensor):
756
+ labels = labels.to(device)
757
+
758
+ image_grid_thw = payload["image_grid_thw"]
759
+ if isinstance(image_grid_thw, torch.Tensor):
760
+ image_grid_thw = image_grid_thw.to(device)
761
+
762
+ pixel_values_videos = payload["pixel_values_videos"]
763
+ if isinstance(pixel_values_videos, torch.Tensor):
764
+ pixel_values_videos = pixel_values_videos.to(device)
765
+
766
+ video_grid_thw = payload["video_grid_thw"]
767
+ if isinstance(video_grid_thw, torch.Tensor):
768
+ video_grid_thw = video_grid_thw.to(device)
769
+
770
+ video_audio_values = payload["video_audio_values"]
771
+ if isinstance(video_audio_values, torch.Tensor):
772
+ video_audio_values = video_audio_values.to(device)
773
+
774
+ video_audio_masks = payload["video_audio_masks"]
775
+ if isinstance(video_audio_masks, torch.Tensor):
776
+ video_audio_masks = video_audio_masks.to(device)
777
+
778
+ pixel_values = payload["pixel_values"]
779
+ if isinstance(pixel_values, torch.Tensor):
780
+ pixel_values = pixel_values.to(device)
781
+
782
+ attention_mask = None
783
+ output_attentions = (
784
+ output_attentions
785
+ if output_attentions is not None
786
+ else self.config.vision_config.output_attentions
787
+ )
788
+ output_hidden_states = (
789
+ output_hidden_states
790
+ if output_hidden_states is not None
791
+ else self.config.vision_config.output_hidden_states
792
+ )
793
+ return_dict = (
794
+ return_dict if return_dict is not None else self.config.use_return_dict
795
+ )
796
+
797
+ if inputs_embeds is None and past_key_values is None:
798
+ inputs_embeds, labels = self.model.extract_inputs_embeds(
799
+ input_ids=input_ids,
800
+ labels=labels,
801
+ pixel_values=pixel_values,
802
+ discrete_pixel_values=discrete_pixel_values,
803
+ past_key_values=past_key_values,
804
+ image_sizes=image_sizes,
805
+ mm_query_lengths=mm_query_lengths,
806
+ non_mm_query_lengths=non_mm_query_lengths,
807
+ img_start_ids_list=img_start_ids_list,
808
+ num_queries_vis_abstractors=num_queries_vis_abstractors,
809
+ num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
810
+ first_last_frames_slows=first_last_frames_slows,
811
+ is_videos=is_videos,
812
+ image_grid_thw=image_grid_thw,
813
+ pixel_values_videos=pixel_values_videos,
814
+ video_grid_thw=video_grid_thw,
815
+ video_audio_values=video_audio_values,
816
+ video_audio_masks=video_audio_masks,
817
+ audio_values=audio_values,
818
+ discrete_audio_values=discrete_audio_values,
819
+ discrete_audio_value_num_per_sample=discrete_audio_value_num_per_sample,
820
+ audio_masks=audio_masks,
821
+ )
822
+
823
+ if labels is not None and labels.size(1) > 32768:
824
+ print(
825
+ f"[RANK {rank} debug] ❌ labels.size(1) > 32768. labels.size(): {labels.size()}"
826
+ )
827
+
828
+ if inputs_embeds is not None:
829
+ input_ids = None
830
+
831
+ import os
832
+
833
+ rank = int(os.environ.get("RANK", -1))
834
+
835
+ if inputs_embeds is not None:
836
+ expected_hidden_size = self.config.text_config.hidden_size
837
+ if inputs_embeds.shape[-1] != expected_hidden_size:
838
+ print(f"[RANK {rank}] ❌ inputs_embeds dimension mismatch!")
839
+ print(
840
+ f" Expected: {expected_hidden_size}, Got: {inputs_embeds.shape[-1]}"
841
+ )
842
+
843
+ if labels is not None:
844
+ vocab_size = self.get_input_embeddings().num_embeddings
845
+ valid_labels = labels[labels != -100]
846
+ if len(valid_labels) > 0:
847
+ if (valid_labels >= vocab_size).any() or (valid_labels < 0).any():
848
+ print(f"[RANK {rank}] ❌ CRITICAL: labels out of vocab range!")
849
+ print(
850
+ f" labels min/max: {valid_labels.min().item()}/{valid_labels.max().item()}"
851
+ )
852
+ print(f" vocab_size: {vocab_size}")
853
+ print(
854
+ f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}"
855
+ )
856
+
857
+ if attention_mask is not None and inputs_embeds is not None:
858
+ if attention_mask.shape[1] != inputs_embeds.shape[1]:
859
+ print(f"[RANK {rank}] ❌ attention_mask shape mismatch!")
860
+ print(
861
+ f" attention_mask: {attention_mask.shape}, inputs_embeds: {inputs_embeds.shape}"
862
+ )
863
+
864
+ if position_ids is not None:
865
+ max_position = position_ids.max().item()
866
+ if hasattr(self.language_model.config, "max_position_embeddings"):
867
+ max_allowed = self.language_model.config.max_position_embeddings
868
+ if max_position >= max_allowed:
869
+ print(f"[RANK {rank}] ❌ position_ids out of range!")
870
+ print(f" max_position: {max_position}, max_allowed: {max_allowed}")
871
+
872
+ if self.sp_manager is not None:
873
+
874
+ batch_size, seqlen, hidden_size = inputs_embeds.shape
875
+
876
+ sp_group = get_ulysses_sequence_parallel_group()
877
+ sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group)
878
+
879
+ inputs_embeds, labels = self._pad_sequence_for_sp(
880
+ inputs_embeds, labels, sp_world_size
881
+ )
882
+
883
+ if position_ids is None:
884
+ position_ids = torch.arange(
885
+ seqlen, device=inputs_embeds.device, dtype=torch.long
886
+ )
887
+ position_ids = (
888
+ position_ids.unsqueeze(0).expand(batch_size, -1).contiguous()
889
+ )
890
+
891
+ inputs_embeds = slice_input_tensor(
892
+ inputs_embeds, 1, padding=False, group=sp_group
893
+ )
894
+ labels = slice_input_tensor(labels, 1, padding=False, group=sp_group)
895
+ use_cache = False
896
+
897
+ outputs = self.language_model.base_model(
898
+ input_ids=input_ids,
899
+ inputs_embeds=inputs_embeds,
900
+ attention_mask=attention_mask,
901
+ position_ids=position_ids,
902
+ past_key_values=past_key_values,
903
+ use_cache=use_cache,
904
+ output_attentions=output_attentions,
905
+ output_hidden_states=output_hidden_states,
906
+ return_dict=return_dict,
907
+ )
908
+
909
+ hidden_states = outputs[0]
910
+ hidden_states = hidden_states * self.config.text_config.logits_scaling
911
+
912
+ loss = None
913
+ logits = None
914
+
915
+ if labels is not None:
916
+ if self.use_liger and self.use_fused_ce:
917
+ shift_labels = labels[..., 1:].contiguous()
918
+ shift_labels = shift_labels.view(-1)
919
+
920
+ hidden_states = hidden_states[..., :-1, :].contiguous()
921
+ hidden_states = hidden_states.view(
922
+ -1, self.language_model.config.hidden_size
923
+ ).to(self.language_model.lm_head.weight.dtype)
924
+
925
+ import os
926
+
927
+ rank = int(os.environ.get("RANK", -1))
928
+
929
+ vocab_size = self.language_model.lm_head.weight.shape[0]
930
+ valid_labels = shift_labels[shift_labels != -100]
931
+ if len(valid_labels) > 0 and (
932
+ (valid_labels >= vocab_size).any() or (valid_labels < 0).any()
933
+ ):
934
+ print(
935
+ f"[RANK {rank}] ❌ CRITICAL: shift_labels out of vocab range!"
936
+ )
937
+ print(
938
+ f" min/max: {valid_labels.min().item()}/{valid_labels.max().item()}, vocab: {vocab_size}"
939
+ )
940
+ print(
941
+ f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}"
942
+ )
943
+
944
+ lce = LigerFusedLinearCrossEntropyLoss(reduction=self.reduction)
945
+ try:
946
+ loss = lce(
947
+ self.language_model.lm_head.weight, hidden_states, shift_labels
948
+ )
949
+ except RuntimeError as e:
950
+ print(
951
+ f"[RANK {rank}] ❌ FATAL: LigerFusedLinearCrossEntropyLoss failed!"
952
+ )
953
+ print(f" Error: {e}")
954
+ print(
955
+ f" hidden_states: shape={hidden_states.shape}, dtype={hidden_states.dtype}"
956
+ )
957
+ print(
958
+ f" shift_labels: shape={shift_labels.shape}, unique_values={torch.unique(shift_labels).tolist()[:20]}"
959
+ )
960
+ print(
961
+ f" lm_head.weight: shape={self.language_model.lm_head.weight.shape}"
962
+ )
963
+ raise
964
+ elif self.use_liger:
965
+ logits = self.language_model.lm_head(hidden_states)
966
+
967
+ shift_logits = logits[..., :-1, :].contiguous()
968
+ shift_labels = labels[..., 1:].contiguous()
969
+
970
+ loss_fct = LigerCrossEntropyLoss(reduction=self.reduction)
971
+ shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
972
+ shift_labels = shift_labels.view(-1)
973
+ shift_labels = shift_labels.to(shift_logits.device)
974
+ loss = loss_fct(shift_logits, shift_labels)
975
+ else:
976
+ logits = self.language_model.lm_head(hidden_states)
977
+
978
+ shift_logits = logits[..., :-1, :].contiguous()
979
+ shift_labels = labels[..., 1:].contiguous()
980
+
981
+ loss_fct = CrossEntropyLoss(reduction=self.reduction)
982
+ shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
983
+ shift_labels = shift_labels.view(-1)
984
+ shift_labels = shift_labels.to(shift_logits.device)
985
+ loss = loss_fct(shift_logits, shift_labels)
986
+
987
+ if self.sp_manager is not None:
988
+ loss = gather_outputs_and_unpad(
989
+ loss, gather_dim=0, unpad_dim=0, padding_size=0, group=sp_group
990
+ )
991
+
992
+ if self.use_meansum_loss:
993
+ loss = loss.view(labels.size(0), -1).mean(dim=1).sum()
994
+
995
+ elif self.use_sqrtsum_loss:
996
+ per_token = loss.view(labels.size(0), -1)
997
+ per_sample_mean = per_token.mean(dim=1)
998
+
999
+ with torch.no_grad():
1000
+ labels_2d = labels.view(labels.size(0), -1)
1001
+ ignore_index = getattr(self.config, "ignore_index", -100)
1002
+ valid_mask = labels_2d.ne(ignore_index)
1003
+ valid_count = valid_mask.sum(dim=1).clamp(min=1).float()
1004
+ raw_w = valid_count.sqrt()
1005
+ w_mean = raw_w.mean().clamp(min=1e-6)
1006
+ norm_w = raw_w / w_mean
1007
+
1008
+ loss = (per_sample_mean * norm_w).sum()
1009
+
1010
+ elif self.use_turnmeansum_loss:
1011
+ with torch.no_grad():
1012
+ mask = shift_labels.view(labels.size(0), -1).ne(
1013
+ self.config.ignore_index
1014
+ )
1015
+ prev_mask = mask.roll(shifts=1, dims=1)
1016
+ prev_mask[:, 0] = False
1017
+
1018
+ turn_starts = mask & (~prev_mask)
1019
+
1020
+ turn_count = turn_starts.sum(dim=1).clamp(min=1).float()
1021
+
1022
+ loss = (loss.view(labels.size(0), -1).mean(dim=1) * turn_count).sum()
1023
+
1024
+ if self.sp_manager is not None:
1025
+ loss = loss / self.sp_manager.device_mesh.shape[1]
1026
+ if not return_dict:
1027
+ output = (logits,) + outputs[1:]
1028
+ return (loss,) + output if loss is not None else output
1029
+
1030
+ return CausalLMOutputWithPast(
1031
+ loss=loss,
1032
+ logits=logits,
1033
+ past_key_values=outputs.past_key_values,
1034
+ hidden_states=outputs.hidden_states,
1035
+ attentions=outputs.attentions,
1036
+ )
1037
+
1038
+ def save_pretrained(
1039
+ self,
1040
+ save_directory: Union[str, os.PathLike],
1041
+ *args,
1042
+ **kwargs,
1043
+ ):
1044
+
1045
+ state_dict = (
1046
+ kwargs["state_dict"]
1047
+ if kwargs.get("state_dict", None)
1048
+ else self.state_dict()
1049
+ )
1050
+ partial_state_dict = self.get_pretrained_state_dict(
1051
+ state_dict,
1052
+ )
1053
+ kwargs["state_dict"] = partial_state_dict
1054
+ kwargs["safe_serialization"] = self.is_safetensor_save
1055
+ kwargs.setdefault("max_shard_size", self.save_shard_size)
1056
+ super().save_pretrained(save_directory, *args, **kwargs)
1057
+ if self.is_qwen_visual:
1058
+ self.config.architectures = ["HCXVisionV2ForCausalLM"]
1059
+ else:
1060
+ self.config.architectures = ["HCXVisionForCausalLM"]
1061
+ self.config.auto_map["AutoModelForCausalLM"] = (
1062
+ "modeling_vlm.HCXVisionForCausalLM"
1063
+ )
1064
+ self.config.auto_map["AutoModelForSequenceClassification"] = (
1065
+ "modeling_vlm.HCXVisionForSequenceClassification"
1066
+ )
1067
+ self.config.save_pretrained(save_directory)
1068
+
1069
+ def get_pretrained_state_dict(self, state_dict):
1070
+ vision_key = "vision_model."
1071
+ llm_keys = ["language_model."]
1072
+ head_key = "lm_head."
1073
+
1074
+ for key in list(state_dict.keys()):
1075
+ if self.save_only_vision:
1076
+ for llm_key in llm_keys:
1077
+ if llm_key in key:
1078
+ state_dict.pop(key)
1079
+ if key.startswith(head_key):
1080
+ state_dict.pop(key)
1081
+ elif self.save_only_qformer:
1082
+ if f"{vision_key}" in key:
1083
+ state_dict.pop(key)
1084
+
1085
+ return state_dict
preprocessor.py ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_vlm.HCXVisionV2Processor"
4
+ },
5
+ "do_convert_rgb": true,
6
+ "do_normalize": true,
7
+ "do_rescale": true,
8
+ "do_resize": true,
9
+ "image_mean": [
10
+ 0.48145466,
11
+ 0.4578275,
12
+ 0.40821073
13
+ ],
14
+ "image_processor_type": "Qwen2VLImageProcessor",
15
+ "image_std": [
16
+ 0.26862954,
17
+ 0.26130258,
18
+ 0.27577711
19
+ ],
20
+ "max_pixels": 2073600,
21
+ "merge_size": 2,
22
+ "min_pixels": 3136,
23
+ "patch_size": 14,
24
+ "processor_class": "HCXVisionV2Processor",
25
+ "resample": 3,
26
+ "rescale_factor": 0.00392156862745098,
27
+ "size": {
28
+ "longest_edge": 12845056,
29
+ "shortest_edge": 3136
30
+ },
31
+ "temporal_patch_size": 2
32
+ }
processing_vlm.py ADDED
@@ -0,0 +1,963 @@
1
+ import copy
2
+ import math
3
+ import os
4
+ from typing import Dict, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import Qwen2_5_VLProcessor
10
+ from transformers.image_processing_utils import (
11
+ BaseImageProcessor,
12
+ BatchFeature,
13
+ get_size_dict,
14
+ )
15
+ from transformers.image_transforms import (
16
+ convert_to_rgb,
17
+ get_resize_output_image_size,
18
+ resize,
19
+ to_channel_dimension_format,
20
+ )
21
+ from transformers.image_utils import (
22
+ OPENAI_CLIP_MEAN,
23
+ OPENAI_CLIP_STD,
24
+ ChannelDimension,
25
+ ImageInput,
26
+ PILImageResampling,
27
+ get_image_size,
28
+ infer_channel_dimension_format,
29
+ is_scaled_image,
30
+ make_list_of_images,
31
+ to_numpy_array,
32
+ valid_images,
33
+ )
34
+ from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import (
35
+ Qwen2_5_VLProcessorKwargs,
36
+ )
37
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
38
+ from transformers.utils import TensorType, logging
39
+ from transformers.video_utils import VideoInput
40
+ from typing_extensions import Unpack
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ def determine_possible_resolutions(
46
+ anyres: bool, max_num_grids: int, grid_size: int, use_1x1_grid: bool = False
47
+ ):
48
+ """총 max_num_grids 이하의 possible resolution 조합을 찾아 반환합니다.
49
+ For example, when max_num_grids is 4, the possible grid layouts are [1x1, 1x2, 1x3, 1x4, 2x1, 2x2, 3x1, 4x1], so the result is computed as follows.
50
+ >>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336)
51
+ >>> print(possible_resolutions)
52
+ [[336, 336], [336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]]
53
+ """
54
+ possible_resolutions = []
55
+ if anyres:
56
+ assert max_num_grids > 0
57
+ for i in range(1, max_num_grids + 1):
58
+ for j in range(1, max_num_grids + 1):
59
+ if i == 1 and j == 1 and not use_1x1_grid:
60
+ continue
61
+ if i * j <= max_num_grids:
62
+ possible_resolutions.append([i, j])
63
+
64
+ possible_resolutions = [
65
+ [ys * grid_size, xs * grid_size] for ys, xs in possible_resolutions
66
+ ]
67
+
68
+ return possible_resolutions
69
+
70
+
71
+ def divide_to_grids(
72
+ image: np.array, grid_size: int, input_data_format=None
73
+ ) -> List[np.array]:
74
+ """local image 를 (grid_size x grid_size) grid 로 divide"""
75
+ grids = []
76
+ height, width = get_image_size(image, channel_dim=input_data_format)
77
+ for i in range(0, height, grid_size):
78
+ for j in range(0, width, grid_size):
79
+ if input_data_format == ChannelDimension.LAST:
80
+ grid = image[i : i + grid_size, j : j + grid_size]
81
+ else:
82
+ grid = image[:, i : i + grid_size, j : j + grid_size]
83
+ grids.append(grid)
84
+
85
+ return grids
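+
+
+ # Editor's note: an illustrative sketch only. A channel-last 336x672 image is split
+ # row by row into two 336x336 grids.
+ def _example_divide_to_grids():
+     image = np.zeros((336, 672, 3), dtype=np.uint8)
+     grids = divide_to_grids(image, grid_size=336, input_data_format=ChannelDimension.LAST)
+     assert len(grids) == 2 and grids[0].shape == (336, 336, 3)
+     return grids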
86
+
87
+
88
+ def pad(
89
+ image: np.array,
90
+ target_size: tuple,
91
+ background_color=(127, 127, 127),
92
+ input_data_format=None,
93
+ ) -> np.array:
94
+ """image 양옆, 좌우에 padding 을 하여 target_height, target_width 만큼 키움"""
95
+ target_height, target_width = target_size
96
+ height, width = get_image_size(image, channel_dim=input_data_format)
97
+
98
+ result = np.empty((target_height, target_width, image.shape[2]), dtype=image.dtype)
99
+ for i in range(image.shape[2]):
100
+ result[..., i].fill(background_color[i])
101
+
102
+ paste_x = (target_width - width) // 2
103
+ paste_y = (target_height - height) // 2
104
+
105
+ result[paste_y : paste_y + height, paste_x : paste_x + width, :] = image
106
+
107
+ return result
108
+
109
+
110
+ def expand2square(
111
+ image: np.array,
112
+ bboxes_dict=None,
113
+ background_color=(127, 127, 127),
114
+ input_data_format=None,
115
+ ) -> np.array:
116
+ """
117
+ Make the image square by creating a new canvas and pasting the image onto it.
118
+ Note that the image is pasted at the center, so the padding is added on the left/right or on the top/bottom.
119
+ Args:
120
+ image: numpy array
121
+ bboxes_dict: dict, {"ocr": NDArray shape (N, 4, 2), "html": NDArray shape (N, 4, 2), ... }
122
+ boxes are unified to the `[[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]]` format, so various box types such as OCR and HTML can be handled at once
123
+ background_color: tuple, RGB
124
+ # >>> _img = np.ones((80, 100), dtype=np.uint8) * 100
125
+ # >>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]],
126
+ # ... [[30, 30], [40, 30], [40, 40], [30, 40]]])}
127
+ # >>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255))
128
+ # >>> _img.shape
129
+ # (100, 100)
130
+ # >>> guessed_ocr_bboxes = np.array([[[20, 10], [30, 10], [30, 20], [20, 20]],
131
+ # ... [[40, 30], [50, 30], [50, 40], [40, 40]]])
132
+ # >>> np.testing.assert_array_almost_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None
133
+ # True
134
+ """
135
+ height, width = get_image_size(image, channel_dim=input_data_format)
136
+ if width == height:
137
+ return image, bboxes_dict
138
+ elif width > height:
139
+ result = np.empty((width, width, image.shape[2]), dtype=image.dtype)
140
+ for i in range(image.shape[2]):
141
+ result[..., i].fill(background_color[i])
142
+
143
+ result[(width - height) // 2 : (width - height) // 2 + height, :] = image
144
+ if bboxes_dict is not None:
145
+ for key in bboxes_dict:
146
+ bboxes_dict[key][:, :, 1] += (width - height) // 2
147
+ return result, bboxes_dict
148
+ else:
149
+ result = np.empty((height, height, image.shape[2]), dtype=image.dtype)
150
+ for i in range(image.shape[2]):
151
+ result[..., i].fill(background_color[i])
152
+
153
+ result[:, (height - width) // 2 : (height - width) // 2 + width] = image
154
+ if bboxes_dict is not None:
155
+ for key in bboxes_dict:
156
+ bboxes_dict[key][:, :, 0] += (height - width) // 2
157
+ return result, bboxes_dict
158
+
159
+
160
+ def resize_longside(
161
+ image: np.array,
162
+ size: int,
163
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
164
+ data_format: Optional[Union[str, ChannelDimension]] = None,
165
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
166
+ ):
167
+ """
168
+ Resize the image so that its longer side matches `size`.
169
+ """
170
+ height, width = get_image_size(image, channel_dim=input_data_format)
171
+
172
+ if width == height:
173
+ target_height, target_width = size, size
174
+ elif width > height:
175
+ target_width = size
176
+ target_height = math.ceil(height / width * size)
177
+ else:
178
+ target_width = math.ceil(width / height * size)
179
+ target_height = size
180
+
181
+ return resize(
182
+ image,
183
+ size=(target_height, target_width),
184
+ resample=resample,
185
+ data_format=data_format,
186
+ input_data_format=input_data_format,
187
+ )
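+
+
+ # Editor's note: an illustrative sketch only. A 600x800 (H x W) image resized with
+ # size=336 keeps its aspect ratio: the long side becomes 336 and the short side
+ # becomes ceil(600 / 800 * 336) = 252.
+ def _example_resize_longside():
+     image = np.zeros((600, 800, 3), dtype=np.uint8)
+     resized = resize_longside(image, size=336, input_data_format=ChannelDimension.LAST)
+     assert resized.shape[:2] == (252, 336)
+     return resized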
188
+
189
+
190
+ def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
191
+ """From LLaVA-Next (https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py)
192
+ Selects the best resolution from a list of possible resolutions based on the original size.
193
+ This is done by calculating the effective and wasted resolution for each possible resolution.
194
+ The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
195
+
196
+ Args:
197
+ original_size (tuple):
198
+ The original size of the image in the format (height, width).
199
+ possible_resolutions (list):
200
+ A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
201
+
202
+ Returns:
203
+ tuple: The best fit resolution in the format (height, width).
204
+ """
205
+ original_height, original_width = original_size
206
+ best_fit = None
207
+ max_effective_resolution = 0
208
+ min_wasted_resolution = float("inf")
209
+
210
+ for height, width in possible_resolutions:
211
+ scale = min(width / original_width, height / original_height)
212
+ downscaled_width, downscaled_height = int(original_width * scale), int(
213
+ original_height * scale
214
+ )
215
+ effective_resolution = min(
216
+ downscaled_width * downscaled_height, original_width * original_height
217
+ )
218
+ wasted_resolution = (width * height) - effective_resolution
219
+
220
+ if effective_resolution > max_effective_resolution or (
221
+ effective_resolution == max_effective_resolution
222
+ and wasted_resolution < min_wasted_resolution
223
+ ):
224
+ max_effective_resolution = effective_resolution
225
+ min_wasted_resolution = wasted_resolution
226
+ best_fit = (height, width)
227
+
228
+ return best_fit
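+
+
+ # Editor's note: an illustrative sketch only, using the resolutions produced by
+ # determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336).
+ # For a 500x1000 (H x W) image the 1x2-grid canvas (336, 672) is returned: its 2:1
+ # aspect ratio matches the image, so the effective resolution is maximal and no
+ # canvas area is wasted.
+ def _example_select_best_resolution():
+     possible_resolutions = determine_possible_resolutions(
+         anyres=True, max_num_grids=4, grid_size=336
+     )
+     return select_best_resolution((500, 1000), possible_resolutions)  # -> (336, 672)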
229
+
230
+
231
+ def _get_local_grids_output_size(
232
+ image: np.array, target_resolution: tuple, input_data_format=None
233
+ ):
234
+ original_height, original_width = get_image_size(
235
+ image, channel_dim=input_data_format
236
+ )
237
+ target_height, target_width = target_resolution
238
+
239
+ scale_w = target_width / original_width
240
+ scale_h = target_height / original_height
241
+
242
+ if scale_w < scale_h:
243
+ new_width = target_width
244
+ new_height = min(math.ceil(original_height * scale_w), target_height)
245
+ else:
246
+ new_height = target_height
247
+ new_width = min(math.ceil(original_width * scale_h), target_width)
248
+
249
+ return new_height, new_width
250
+
251
+
252
+ def determine_anyres_num_vision_patches(
253
+ num_grids,
254
+ image_size,
255
+ grid_size,
256
+ patch_size,
257
+ possible_resolutions,
258
+ anyres=False,
259
+ unpad=True,
260
+ num_queries_vis_abstractor=0,
261
+ num_queries_vis_abstractor_slow=0,
262
+ video=False,
263
+ first_last_frames_slow=False,
264
+ is_first_or_last_frames=False,
265
+ ):
266
+ """visual tokens 수를 계산해주는 함수"""
267
+ if not anyres:
268
+ return (
269
+ num_queries_vis_abstractor
270
+ if num_queries_vis_abstractor > 0
271
+ else (grid_size // patch_size) ** 2
272
+ )
273
+
274
+ if num_queries_vis_abstractor > 0:
275
+ num_patch_per_grid = int(num_queries_vis_abstractor**0.5)
276
+ else:
277
+ num_patch_per_grid = grid_size // patch_size
278
+
279
+ num_global_per_grid = num_patch_per_grid
280
+
281
+ height, width = select_best_resolution(image_size, possible_resolutions)
282
+
283
+ num_patch_height = (height // grid_size) * num_patch_per_grid
284
+ num_patch_width = (width // grid_size) * num_patch_per_grid
285
+
286
+ if unpad:
287
+ original_height, original_width = image_size
288
+
289
+ original_aspect_ratio = original_width / original_height
290
+ current_aspect_ratio = num_patch_width / num_patch_height
291
+
292
+ if original_aspect_ratio > current_aspect_ratio:
293
+ scale_factor = num_patch_width / original_width
294
+ new_height = int(original_height * scale_factor)
295
+ padding = (num_patch_height - new_height) // 2
296
+ num_patch_height = num_patch_height - padding * 2
297
+ else:
298
+ scale_factor = num_patch_height / original_height
299
+ new_width = int(original_width * scale_factor)
300
+ padding = (num_patch_width - new_width) // 2
301
+ num_patch_width = num_patch_width - padding * 2
302
+
303
+ num_patches = num_patch_width * num_patch_height + num_patch_height
304
+ else:
305
+ num_patches = num_patch_width * num_patch_height
306
+
307
+ if num_queries_vis_abstractor_slow > 0:
308
+ if first_last_frames_slow:
309
+ if is_first_or_last_frames:
310
+ num_patches += (
311
+ num_queries_vis_abstractor_slow - num_queries_vis_abstractor
312
+ )
313
+ else:
314
+ num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
315
+ assert unpad is False
316
+
317
+ if not video:
318
+ num_patches += num_global_per_grid**2
319
+
320
+ return num_patches
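+
+
+ # Editor's note: an illustrative sketch only. For a 500x1000 (H x W) image with
+ # 336-pixel grids and 14-pixel ViT patches, the best-fit canvas is 336x672 (a 1x2
+ # grid), giving 24x48 local patches plus one extra position per patch row (24 in
+ # total, as in LLaVA-NeXT unpadding) and a 24x24 global grid:
+ # 48 * 24 + 24 + 576 = 1752 visual tokens.
+ def _example_count_vision_patches():
+     possible_resolutions = determine_possible_resolutions(
+         anyres=True, max_num_grids=4, grid_size=336
+     )
+     return determine_anyres_num_vision_patches(
+         num_grids=None,  # unused by the current implementation
+         image_size=(500, 1000),
+         grid_size=336,
+         patch_size=14,
+         possible_resolutions=possible_resolutions,
+         anyres=True,
+         unpad=True,
+     )  # -> 1752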
321
+
322
+
323
+ class HCXVisionImageProcessor(BaseImageProcessor):
324
+ r"""
325
+ Constructs a VLM image processor. Based on [`CLIPImageProcessor`], incorporating additional techniques for processing high-resolution images.
326
+
327
+ Args:
328
+ anyres: (bool) whether to use the anyres feature
329
+ unpad: (bool) when anyres is used, whether to use the unpad feature (visual tokens that correspond to pure padding regions are removed from the LLM input)
330
+ num_queries_vis_abstractor: (int) number of visual queries when a resampler is used for each grid
331
+ possible_resolutions: (List) possible resolution combinations when anyres is used, e.g. [[336, 336], [336, 672], [672, 336]]
332
+ patch_size: (int) ViT patch size
333
+ pad_to_square: (bool) whether to pad the image to a square. If False, the image is not square, so it goes through a center crop before being fed to the ViT
334
+ """
335
+
336
+ model_input_names = ["pixel_values"]
337
+
338
+ def __init__(
339
+ self,
340
+ do_resize: bool = True,
341
+ size: Dict[str, int] = None,
342
+ anyres: bool = False,
343
+ unpad: bool = False,
344
+ num_queries_vis_abstractor: int = 0,
345
+ possible_resolutions: List = [],
346
+ patch_size: int = 14,
347
+ pad_to_square: bool = True,
348
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
349
+ do_center_crop: bool = True,
350
+ crop_size: Dict[str, int] = None,
351
+ do_rescale: bool = True,
352
+ rescale_factor: Union[int, float] = 1 / 255,
353
+ do_normalize: bool = True,
354
+ image_mean: Optional[Union[float, List[float]]] = None,
355
+ image_std: Optional[Union[float, List[float]]] = None,
356
+ do_convert_rgb: bool = True,
357
+ **kwargs,
358
+ ) -> None:
359
+ super().__init__(**kwargs)
360
+ size = size if size is not None else {"shortest_edge": 336}
361
+ size = get_size_dict(size, default_to_square=False)
362
+ crop_size = (
363
+ crop_size if crop_size is not None else {"height": 336, "width": 336}
364
+ )
365
+ crop_size = get_size_dict(
366
+ crop_size, default_to_square=True, param_name="crop_size"
367
+ )
368
+
369
+ self.do_resize = do_resize
370
+ self.size = size
371
+ self.anyres = anyres
372
+ self.unpad = unpad
373
+ self.num_queries_vis_abstractor = num_queries_vis_abstractor
374
+ self.possible_resolutions = [
375
+ _resolution for _resolution in possible_resolutions
376
+ ]
377
+ self.patch_size = patch_size
378
+ self.pad_to_square = pad_to_square
379
+ self.resample = resample
380
+ self.do_center_crop = do_center_crop
381
+ self.crop_size = crop_size
382
+ self.do_rescale = do_rescale
383
+ self.rescale_factor = rescale_factor
384
+ self.do_normalize = do_normalize
385
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
386
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
387
+ self.do_convert_rgb = do_convert_rgb
388
+
389
+ def resize(
390
+ self,
391
+ image: np.ndarray,
392
+ size: Dict[str, int],
393
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
394
+ data_format: Optional[Union[str, ChannelDimension]] = None,
395
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
396
+ **kwargs,
397
+ ) -> np.ndarray:
398
+ default_to_square = True
399
+ if "shortest_edge" in size:
400
+ size = size["shortest_edge"]
401
+ default_to_square = False
402
+ elif "height" in size and "width" in size:
403
+ size = (size["height"], size["width"])
404
+ else:
405
+ raise ValueError(
406
+ "Size must contain either 'shortest_edge' or 'height' and 'width'."
407
+ )
408
+
409
+ output_size = get_resize_output_image_size(
410
+ image,
411
+ size=size,
412
+ default_to_square=default_to_square,
413
+ input_data_format=input_data_format,
414
+ )
415
+
416
+ return resize(
417
+ image,
418
+ size=output_size,
419
+ resample=resample,
420
+ data_format=data_format,
421
+ input_data_format=input_data_format,
422
+ **kwargs,
423
+ )
424
+
425
+ def _preprocess(
426
+ self,
427
+ images: ImageInput,
428
+ do_resize: bool = None,
429
+ size: Dict[str, int] = None,
430
+ resample: PILImageResampling = None,
431
+ do_center_crop: bool = None,
432
+ crop_size: int = None,
433
+ do_rescale: bool = None,
434
+ rescale_factor: float = None,
435
+ do_normalize: bool = None,
436
+ image_mean: Optional[Union[float, List[float]]] = None,
437
+ image_std: Optional[Union[float, List[float]]] = None,
438
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
439
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
440
+ ) -> Image.Image:
441
+ images = make_list_of_images(images)
442
+
443
+ if do_resize:
444
+ images = [
445
+ self.resize(
446
+ image=image,
447
+ size=size,
448
+ resample=resample,
449
+ input_data_format=input_data_format,
450
+ )
451
+ for image in images
452
+ ]
453
+
454
+ if do_center_crop:
455
+ images = [
456
+ self.center_crop(
457
+ image=image, size=crop_size, input_data_format=input_data_format
458
+ )
459
+ for image in images
460
+ ]
461
+
462
+ if do_rescale:
463
+ images = [
464
+ self.rescale(
465
+ image=image,
466
+ scale=rescale_factor,
467
+ input_data_format=input_data_format,
468
+ )
469
+ for image in images
470
+ ]
471
+
472
+ if do_normalize:
473
+ images = [
474
+ self.normalize(
475
+ image=image,
476
+ mean=image_mean,
477
+ std=image_std,
478
+ input_data_format=input_data_format,
479
+ )
480
+ for image in images
481
+ ]
482
+
483
+ images = [
484
+ to_channel_dimension_format(
485
+ image, data_format, input_channel_dim=input_data_format
486
+ )
487
+ for image in images
488
+ ]
489
+
490
+ return images
491
+
492
+ def _resize_for_local_grids(
493
+ self,
494
+ image: np.array,
495
+ target_resolution: tuple,
496
+ resample,
497
+ input_data_format: ChannelDimension,
498
+ ) -> np.array:
499
+ new_height, new_width = _get_local_grids_output_size(
500
+ image, target_resolution, input_data_format
501
+ )
502
+
503
+ resized_image = resize(
504
+ image,
505
+ (new_height, new_width),
506
+ resample=resample,
507
+ input_data_format=input_data_format,
508
+ )
509
+
510
+ return resized_image
511
+
512
+ def _pad_for_patching(
513
+ self,
514
+ image: np.array,
515
+ target_resolution: tuple,
516
+ input_data_format: ChannelDimension,
517
+ ) -> np.array:
518
+ """
519
+ Pad an image (already resized with its aspect ratio preserved) up to the target resolution, filling the margins with the mean-pixel background color.
520
+ """
521
+ target_height, target_width = target_resolution
522
+
523
+ background_color = tuple(int(x * 255) for x in self.image_mean)
524
+ padded_image = pad(
525
+ image,
526
+ target_size=(target_height, target_width),
527
+ background_color=background_color,
528
+ input_data_format=input_data_format,
529
+ )
530
+
531
+ return padded_image
532
+
533
+ def get_image_grids(
534
+ self,
535
+ image: np.array,
536
+ possible_resolutions,
537
+ grid_size: int,
538
+ resample: PILImageResampling,
539
+ data_format: ChannelDimension,
540
+ input_data_format: ChannelDimension,
541
+ ) -> List[np.array]:
542
+ if not isinstance(possible_resolutions, list):
543
+ raise ValueError(
544
+ "possible_resolutions must be a list of possible resolutions."
545
+ )
546
+
547
+ image_size = get_image_size(image, channel_dim=input_data_format)
548
+ best_resolution = select_best_resolution(image_size, possible_resolutions)
549
+ resized_image = self._resize_for_local_grids(
550
+ image,
551
+ best_resolution,
552
+ resample=resample,
553
+ input_data_format=input_data_format,
554
+ )
555
+ padded_image = self._pad_for_patching(
556
+ resized_image, best_resolution, input_data_format=input_data_format
557
+ )
558
+ local_grids = divide_to_grids(
559
+ padded_image, grid_size=grid_size, input_data_format=input_data_format
560
+ )
561
+
562
+ local_grids = [
563
+ to_channel_dimension_format(
564
+ grid, channel_dim=data_format, input_channel_dim=input_data_format
565
+ )
566
+ for grid in local_grids
567
+ ]
568
+
569
+ return local_grids
570
+
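For orientation, here is a minimal, self-contained sketch of the grid-selection arithmetic that `get_image_grids` relies on. `pick_best_resolution` is a simplified stand-in for the `select_best_resolution` helper defined elsewhere in this file, and the crop size and candidate resolutions below are made-up values, not the repo's defaults.

```python
def pick_best_resolution(image_size, possible_resolutions):
    """Pick the candidate (h, w) that keeps the most usable pixels after an
    aspect-preserving resize (simplified stand-in for select_best_resolution)."""
    orig_h, orig_w = image_size
    best, best_key = None, None
    for h, w in possible_resolutions:
        scale = min(w / orig_w, h / orig_h)
        effective = int(orig_w * scale) * int(orig_h * scale)  # pixels kept
        wasted = h * w - effective                             # padding pixels
        key = (effective, -wasted)
        if best_key is None or key > best_key:
            best_key, best = key, (h, w)
    return best

crop = 336                                              # stands in for crop_size["height"]
candidates = [(336, 672), (672, 336), (672, 672), (336, 1008)]
best = pick_best_resolution((480, 960), candidates)     # raw image is 480 x 960 (H x W)
print(best, (best[0] // crop) * (best[1] // crop))      # (336, 672) -> 1 x 2 = 2 local grids
```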
571
+ def preprocess(
572
+ self,
573
+ images: ImageInput,
574
+ do_resize: bool = None,
575
+ size: Dict[str, int] = None,
576
+ anyres: bool = None,
577
+ unpad: bool = None,
578
+ video: bool = None,
579
+ num_queries_vis_abstractor: int = None,
580
+ possible_resolutions: List = None,
581
+ patch_size: int = None,
582
+ pad_to_square: bool = None,
583
+ resample: PILImageResampling = None,
584
+ do_center_crop: bool = None,
585
+ crop_size: int = None,
586
+ do_rescale: bool = None,
587
+ rescale_factor: float = None,
588
+ do_normalize: bool = None,
589
+ image_mean: Optional[Union[float, List[float]]] = None,
590
+ image_std: Optional[Union[float, List[float]]] = None,
591
+ do_convert_rgb: bool = None,
592
+ return_tensors: Optional[Union[str, TensorType]] = None,
593
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
594
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
595
+ return_dummy_image: bool = False,
596
+ num_queries_vis_abstractor_slow: int = 0,
597
+ first_last_frames_slow: bool = False,
598
+ is_first_or_last_frames: bool = False,
599
+ ):
600
+ """
601
+ HCXVisionImageProcessor 로 image tensor, original image size (width, height), visual tokens
602
+
603
+ :return pixel_values: List of 4D tensor 로 image tensor
604
+ :return image_sizes: List of Dict 로 image width, height [{"width": image 1 의 width, "height": image 1 의 height}, {"width": image 2 의 width, "height": image 2 의 height}, ...]
605
+ :return vision_query_lengths: List of int 로 각 image 가 LLM 입력으로 전달될때 변환되는 visual token 수
606
+ """
607
+ do_resize = do_resize if do_resize is not None else self.do_resize
608
+ size = size if size is not None else self.size
609
+ size = get_size_dict(size, param_name="size", default_to_square=False)
610
+ anyres = anyres if anyres is not None else self.anyres
611
+ unpad = unpad if unpad is not None else self.unpad
612
+ if video:
613
+ unpad = False
614
+ num_queries_vis_abstractor = (
615
+ num_queries_vis_abstractor
616
+ if num_queries_vis_abstractor is not None
617
+ else self.num_queries_vis_abstractor
618
+ )
619
+ possible_resolutions = (
620
+ possible_resolutions
621
+ if possible_resolutions is not None
622
+ else self.possible_resolutions
623
+ )
624
+ patch_size = patch_size if patch_size is not None else self.patch_size
625
+ pad_to_square = (
626
+ pad_to_square if pad_to_square is not None else self.pad_to_square
627
+ )
628
+ resample = resample if resample is not None else self.resample
629
+ do_center_crop = (
630
+ do_center_crop if do_center_crop is not None else self.do_center_crop
631
+ )
632
+ crop_size = crop_size if crop_size is not None else self.crop_size
633
+ crop_size = get_size_dict(
634
+ crop_size, param_name="crop_size", default_to_square=True
635
+ )
636
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
637
+ rescale_factor = (
638
+ rescale_factor if rescale_factor is not None else self.rescale_factor
639
+ )
640
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
641
+ image_mean = image_mean if image_mean is not None else self.image_mean
642
+ image_std = image_std if image_std is not None else self.image_std
643
+ do_convert_rgb = (
644
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
645
+ )
646
+
647
+ if return_dummy_image:
648
+ images = Image.new("RGB", (224, 224), (0, 0, 0))
649
+
650
+ images = make_list_of_images(images)
651
+
652
+ if not valid_images(images):
653
+ raise ValueError(
654
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
655
+ "torch.Tensor, tf.Tensor or jax.ndarray."
656
+ )
657
+
658
+ if do_convert_rgb:
659
+ images = [convert_to_rgb(image) for image in images]
660
+
661
+ images = [to_numpy_array(image) for image in images]
662
+
663
+ if is_scaled_image(images[0]) and do_rescale:
664
+ logger.warning_once(
665
+ "It looks like you are trying to rescale already rescaled images. If the input"
666
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
667
+ )
668
+
669
+ if input_data_format is None:
670
+ input_data_format = infer_channel_dimension_format(images[0])
671
+
672
+ new_images = []
673
+ image_sizes = [
674
+ get_image_size(image, channel_dim=input_data_format) for image in images
675
+ ]
676
+ vision_query_lengths = []
677
+
678
+ assert crop_size["height"] == crop_size["width"]
679
+
680
+ if anyres:
681
+ anyres_global_images = copy.deepcopy(images)
682
+ if pad_to_square:
683
+ background_color = tuple(int(x * 255) for x in self.image_mean)
684
+ anyres_global_images = [
685
+ resize_longside(
686
+ copy.deepcopy(image),
687
+ size["shortest_edge"],
688
+ resample,
689
+ input_data_format,
690
+ )
691
+ for image in anyres_global_images
692
+ ]
693
+ anyres_global_images = [
694
+ expand2square(
695
+ image,
696
+ background_color=background_color,
697
+ input_data_format=input_data_format,
698
+ )[0]
699
+ for image in anyres_global_images
700
+ ]
701
+ else:
702
+ anyres_global_images = [
703
+ self.resize(
704
+ image=image,
705
+ size={
706
+ "height": size["shortest_edge"],
707
+ "width": size["shortest_edge"],
708
+ },
709
+ resample=resample,
710
+ input_data_format=input_data_format,
711
+ )
712
+ for image in anyres_global_images
713
+ ]
714
+ else:
715
+ anyres_global_images = [None for _ in range(len(images))]
716
+ if pad_to_square:
717
+ background_color = tuple(int(x * 255) for x in self.image_mean)
718
+ images = [
719
+ resize_longside(
720
+ image, size["shortest_edge"], resample, input_data_format
721
+ )
722
+ for image in images
723
+ ]
724
+ images = [
725
+ expand2square(
726
+ image,
727
+ background_color=background_color,
728
+ input_data_format=input_data_format,
729
+ )[0]
730
+ for image in images
731
+ ]
732
+
733
+ for image, anyres_global_image, image_size in zip(
734
+ images, anyres_global_images, image_sizes
735
+ ):
736
+ if anyres:
737
+ image_grids = self.get_image_grids(
738
+ image,
739
+ possible_resolutions,
740
+ grid_size=crop_size["height"],
741
+ resample=resample,
742
+ data_format=input_data_format,
743
+ input_data_format=input_data_format,
744
+ )
745
+ if not video:
746
+ image_grids = [anyres_global_image] + image_grids
747
+ else:
748
+ image_grids = [image]
749
+
750
+ pixel_values = self._preprocess(
751
+ image_grids,
752
+ do_resize=do_resize,
753
+ size=size,
754
+ resample=resample,
755
+ do_center_crop=do_center_crop,
756
+ crop_size=crop_size,
757
+ do_rescale=do_rescale,
758
+ rescale_factor=rescale_factor,
759
+ do_normalize=do_normalize,
760
+ image_mean=image_mean,
761
+ image_std=image_std,
762
+ data_format=data_format,
763
+ input_data_format=input_data_format,
764
+ )
765
+
766
+ pixel_values = np.array(pixel_values)
767
+ new_images.append(pixel_values)
768
+
769
+ num_grids = pixel_values.shape[0]
770
+
771
+ vision_query_length = determine_anyres_num_vision_patches(
772
+ num_grids=num_grids,
773
+ image_size=image_size,
774
+ grid_size=crop_size["height"],
775
+ patch_size=patch_size,
776
+ possible_resolutions=possible_resolutions,
777
+ anyres=anyres,
778
+ unpad=unpad,
779
+ num_queries_vis_abstractor=num_queries_vis_abstractor,
780
+ num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
781
+ video=video,
782
+ first_last_frames_slow=first_last_frames_slow,
783
+ is_first_or_last_frames=is_first_or_last_frames,
784
+ )
785
+
786
+ vision_query_lengths.append(vision_query_length)
787
+
788
+ if return_dummy_image:
789
+ vision_query_lengths = []
790
+
791
+ data = {
792
+ "pixel_values": [torch.tensor(new_image) for new_image in new_images],
793
+ "image_sizes": [
794
+ {"width": image_size[1], "height": image_size[0]}
795
+ for image_size in image_sizes
796
+ ],
797
+ "vision_query_lengths": vision_query_lengths,
798
+ }
799
+
800
+ return BatchFeature(data=data)
801
+
802
+ def save_pretrained(
803
+ self,
804
+ save_directory: Union[str, os.PathLike],
805
+ *args,
806
+ **kwargs,
807
+ ):
808
+ self.register_for_auto_class()
809
+ super().save_pretrained(save_directory, *args, **kwargs)
810
+
811
+
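A hedged usage sketch of the image processor above, assuming the repo's preprocessor_config.json maps `AutoImageProcessor` to this class; `"path/to/this/repo"` is a placeholder for a local checkout or the hub id, and the exact tensor shapes depend on the configured crop size and `possible_resolutions`.

```python
from PIL import Image
from transformers import AutoImageProcessor

# trust_remote_code is needed because the class lives in this repo, not in transformers.
proc = AutoImageProcessor.from_pretrained("path/to/this/repo", trust_remote_code=True)
out = proc.preprocess(Image.new("RGB", (640, 480)))

print(list(out.keys()))               # ['pixel_values', 'image_sizes', 'vision_query_lengths']
print(out["pixel_values"][0].shape)   # roughly (num_grids, 3, crop_h, crop_w) for the first image
print(out["image_sizes"])             # [{'width': 640, 'height': 480}]
print(out["vision_query_lengths"])    # visual-token budget per image for the LLM
```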
812
+ class HCXVisionV2Processor(Qwen2_5_VLProcessor):
813
+ attributes = ["image_processor", "tokenizer", "video_processor"]
814
+ image_processor_class = "AutoImageProcessor"
815
+ video_processor_class = "AutoVideoProcessor"
816
+ tokenizer_class = (
817
+ "GPT2Tokenizer",
818
+ "GPT2TokenizerFast",
819
+ "PreTrainedTokenizer",
820
+ "PreTrainedTokenizerFast",
821
+ )
822
+
823
+ def __init__(
824
+ self,
825
+ image_processor=None,
826
+ tokenizer=None,
827
+ video_processor=None,
828
+ chat_template=None,
829
+ **kwargs,
830
+ ):
831
+ self.tokenizer = tokenizer
832
+ super().__init__(
833
+ image_processor,
834
+ tokenizer,
835
+ video_processor,
836
+ chat_template=self.tokenizer.chat_template,
837
+ )
838
+
839
+ def save_pretrained(
840
+ self,
841
+ save_directory: Union[str, os.PathLike],
842
+ *args,
843
+ **kwargs,
844
+ ):
845
+ self.register_for_auto_class()
846
+ super().save_pretrained(save_directory, *args, **kwargs)
847
+
848
+ def __call__(
849
+ self,
850
+ images: ImageInput = None,
851
+ text: Union[
852
+ TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
853
+ ] = None,
854
+ videos: VideoInput = None,
855
+ **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
856
+ ) -> BatchFeature:
857
+ """
858
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
859
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
860
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
861
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
862
+
863
+ Args:
864
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
865
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
866
+ tensor. Both channels-first and channels-last formats are supported.
867
+ text (`str`, `list[str]`, `list[list[str]]`):
868
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
869
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
870
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
871
+ videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
872
+ The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
873
+ tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
874
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
875
+ If set, will return tensors of a particular framework. Acceptable values are:
876
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
877
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
878
+ - `'np'`: Return NumPy `np.ndarray` objects.
879
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
880
+
881
+ Returns:
882
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
883
+
884
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
885
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
886
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
887
+ `None`).
888
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
889
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
890
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
891
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
892
+ """
893
+ output_kwargs = self._merge_kwargs(
894
+ Qwen2_5_VLProcessorKwargs,
895
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
896
+ **kwargs,
897
+ )
898
+
899
+ image_inputs = videos_inputs = {}
900
+ if images is not None:
901
+ image_inputs = self.image_processor(
902
+ images=images, **output_kwargs["images_kwargs"]
903
+ )
904
+ image_grid_thw = image_inputs["image_grid_thw"]
905
+
906
+ if videos is not None:
907
+ videos_inputs = self.video_processor(
908
+ videos=videos, **output_kwargs["videos_kwargs"]
909
+ )
910
+ video_grid_thw = videos_inputs["video_grid_thw"]
911
+
912
+ if not isinstance(text, list):
913
+ text = [text]
914
+
915
+ text = text.copy()
916
+
917
+ if images is not None:
918
+ merge_length = self.image_processor.merge_size**2
919
+ index = 0
920
+ for i in range(len(text)):
921
+ while self.image_token in text[i]:
922
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
923
+ text[i] = text[i].replace(
924
+ self.image_token, "<|placeholder|>" * num_image_tokens, 1
925
+ )
926
+ text[i] = text[i].replace(
927
+ '{"resolution": [w, h]}',
928
+ '{"resolution": ' + str(list(images[i].size)) + "}",
929
+ )
930
+ index += 1
931
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
932
+
933
+ if videos is not None:
934
+ merge_length = self.video_processor.merge_size**2
935
+ index = 0
936
+ for i in range(len(text)):
937
+ while self.video_token in text[i]:
938
+ num_video_tokens = video_grid_thw[index].prod() // merge_length
939
+ text[i] = text[i].replace(
940
+ self.video_token, "<|placeholder|>" * num_video_tokens, 1
941
+ )
942
+ index += 1
943
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
944
+
945
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
946
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop(
947
+ "return_mm_token_type_ids", False
948
+ )
949
+ text_inputs = self.tokenizer(
950
+ text, **output_kwargs["text_kwargs"], return_tensors=None
951
+ )
952
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
953
+
954
+ if return_mm_token_type_ids:
955
+ array_ids = np.array(text_inputs["input_ids"])
956
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
957
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
958
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
959
+
960
+ return BatchFeature(
961
+ data={**text_inputs, **image_inputs, **videos_inputs},
962
+ tensor_type=return_tensors,
963
+ )
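The placeholder expansion in `__call__` reduces to simple arithmetic: each image contributes `grid_t * grid_h * grid_w` vision patches, and `merge_size**2` patches collapse into one image token in the text. The grid values below are made up, not output from the real processor.

```python
merge_size = 2
grid_t, grid_h, grid_w = 1, 54, 74                      # one image_grid_thw entry
num_image_tokens = (grid_t * grid_h * grid_w) // merge_size**2
print(num_image_tokens)                                 # 3996 // 4 = 999 placeholder tokens
```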
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_vlm.HCXVisionV2Processor"
4
+ },
5
+ "processor_class": "HCXVisionV2Processor"
6
+ }
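A hedged sketch of what this `auto_map` entry enables: with `trust_remote_code=True`, `AutoProcessor` resolves `HCXVisionV2Processor` from processing_vlm.py in the repo. `"path/to/this/repo"` is a placeholder for a local checkout or the hub id.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(type(processor).__name__)   # HCXVisionV2Processor
```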
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|im_end|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "sep_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
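A hedged check of the mapping above via the tokenizer: the chat end marker doubles as both EOS and PAD, while `<|endoftext|>` serves as SEP/UNK. The path is a placeholder and the repo's tokenizer files must be present.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(tok.eos_token, tok.pad_token)   # <|im_end|> <|im_end|>
print(tok.sep_token, tok.unk_token)   # <|endoftext|> <|endoftext|>
```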
ta_tok.py ADDED
@@ -0,0 +1,379 @@
1
+ import copy
2
+ import inspect
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torchvision.transforms import Resize
9
+ from transformers import AutoConfig, AutoModel, Siglip2VisionConfig, Siglip2VisionModel
10
+
11
+
12
+ def models_make(model_spec, args=None, load_sd=False) -> torch.nn.Module:
13
+ if args is not None:
14
+ model_args = copy.deepcopy(model_spec["args"])
15
+ model_args.update(args)
16
+ else:
17
+ model_args = model_spec["args"]
18
+ model_params = inspect.signature(models[model_spec["name"]]).parameters
19
+ if "kwargs" not in model_params:
20
+ model_args = {k: v for k, v in model_args.items() if k in model_params}
21
+ model = models[model_spec["name"]](**model_args)
22
+ if load_sd:
23
+ if (
24
+ ("abs_pe" in model_spec["sd"])
25
+ and hasattr(model, "abs_pe")
26
+ and model_spec["sd"]["abs_pe"].shape != model.abs_pe.shape
27
+ ):
28
+ del model_spec["sd"]["abs_pe"]
29
+ msg = model.load_state_dict(model_spec["sd"], strict=False)
30
+ print(msg)
31
+ return model
32
+
33
+
34
+ class Bottleneck(nn.Module):
35
+ def __init__(
36
+ self,
37
+ bottleneck_dim: int,
38
+ input_dim: int,
39
+ output_dim: int,
40
+ token_nums: int,
41
+ regularizer=None,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+ self.token_nums = token_nums
46
+ self.input_dim = input_dim
47
+ self.output_dim = output_dim
48
+ if bottleneck_dim > 0:
49
+ self.bottleneck_dim = bottleneck_dim
50
+ else:
51
+ assert (
52
+ self.input_dim == self.output_dim
53
+ ), "input_dim and output_dim must be the same when bottleneck_dim is not specified"
54
+ self.bottleneck_dim = self.input_dim
55
+
56
+ self.project_dim = self.bottleneck_dim
57
+
58
+ if self.bottleneck_dim > 0:
59
+ self.in_linear = nn.Linear(self.input_dim, self.project_dim)
60
+ self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
61
+ else:
62
+ self.in_linear = self.out_linear = lambda x: x
63
+
64
+ regularizer["args"]["dim"] = self.bottleneck_dim
65
+ regularizer["args"]["token_nums"] = self.token_nums
66
+ self.regularizer = models_make(regularizer)
67
+
68
+ def project_in(self, x):
69
+ assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
70
+ z = self.in_linear(x)
71
+ return z
72
+
73
+ def project_out(self, z_cat):
74
+ z = self.out_linear(z_cat)
75
+ return z
76
+
77
+ def decode(self, bottleneck_rep):
78
+ regularized_z = self.regularizer.decode(bottleneck_rep)
79
+ return self.project_out(regularized_z)
80
+
81
+ def forward(self, x):
82
+ z = self.project_in(x)
83
+ projected_z = z
84
+ regularized_output = self.regularizer(z)
85
+ x_hat = self.project_out(regularized_output["regularized_z"])
86
+ bottleneck_rep = regularized_output.pop("bottleneck_rep")
87
+ return {
88
+ "output": x_hat,
89
+ "bottleneck_rep": bottleneck_rep,
90
+ "projected_z": projected_z,
91
+ **regularized_output,
92
+ }
93
+
94
+
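A hedged sketch of how `Bottleneck` composes with the registry defined further down in this file: `models_make` looks the regularizer up by name, so a SimVQ codebook sits between `project_in` and `project_out`. It assumes ta_tok.py is importable; the sizes are illustrative.

```python
import torch

bn = Bottleneck(
    bottleneck_dim=16, input_dim=32, output_dim=32, token_nums=256,
    regularizer={"name": "simvq", "args": {"codebook_size": 64, "l2_normalized": True}},
)
out = bn(torch.randn(2, 256, 32))       # (batch, n_tokens, input_dim)
print(out["output"].shape)              # torch.Size([2, 256, 32]) reconstructed features
print(out["bottleneck_rep"].shape)      # torch.Size([2, 256]) discrete code indices
```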
95
+ class SimVectorQuantizer(nn.Module):
96
+ def __init__(
97
+ self,
98
+ dim,
99
+ codebook_size,
100
+ l2_normalized=False,
101
+ same_index_shape=True,
102
+ stochastic=False,
103
+ stochastic_temperature=1.0,
104
+ **kwargs,
105
+ ):
106
+ super().__init__()
107
+ self.codebook_size = codebook_size
108
+ self.dim = dim
109
+ assert isinstance(l2_normalized, bool)
110
+ self.l2_normalized = l2_normalized
111
+ self.stochastic = stochastic
112
+ self.eval_deterministic = False
113
+ self.default_stochastic_temperature = stochastic_temperature
114
+
115
+ if self.stochastic:
116
+ if stochastic_temperature > 0:
117
+ self.stochastic_temperature_inv = 1 / stochastic_temperature
118
+ else:
119
+ self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))
120
+
121
+ self.embedding = nn.Embedding(self.codebook_size, self.dim)
122
+ self.embedding_proj = nn.Linear(self.dim, self.dim)
123
+
124
+ self.same_index_shape = same_index_shape
125
+
126
+ def set_eval_deterministic(self, deterministic=True):
127
+ self.eval_deterministic = deterministic
128
+
129
+ def set_stochastic_temperature(self, temperature):
130
+ self.stochastic_temperature_inv = 1 / temperature
131
+
132
+ @torch.autocast(device_type="cuda", enabled=False)
133
+ def get_emb(self):
134
+ emb = self.embedding_proj(self.embedding.weight)
135
+ if self.l2_normalized:
136
+ emb = F.normalize(emb, p=2, dim=-1)
137
+ return emb
138
+
139
+ @torch.autocast(device_type="cuda", enabled=False)
140
+ def forward(self, z):
141
+ emb = self.get_emb()
142
+ z = z.to(emb)
143
+ assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
144
+ if self.l2_normalized:
145
+ z = F.normalize(z, p=2, dim=-1)
146
+
147
+ z_flattened = rearrange(z, "b n d -> (b n) d")
148
+
149
+ if self.stochastic:
150
+ assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
151
+ cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
152
+ probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
153
+ if self.eval_deterministic and not self.training:
154
+ q_indices = torch.argmax(probs, dim=-1)
155
+ else:
156
+ q_indices = torch.multinomial(probs, 1).squeeze(-1)
157
+ else:
158
+ d = (
159
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
160
+ + torch.sum(emb**2, dim=1)
161
+ - 2
162
+ * torch.einsum("bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n"))
163
+ )
164
+ q_indices = torch.argmin(d, dim=1)
165
+
166
+ quantized = F.embedding(
167
+ q_indices,
168
+ emb,
169
+ self.embedding.padding_idx,
170
+ self.embedding.max_norm,
171
+ self.embedding.norm_type,
172
+ self.embedding.scale_grad_by_freq,
173
+ self.embedding.sparse,
174
+ ).view(z.shape)
175
+
176
+ quantized = z + (quantized - z).detach()
177
+
178
+ if self.same_index_shape:
179
+ q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])
180
+
181
+ return_dict = {
182
+ "unregularized_z": z,
183
+ "emb": emb,
184
+ "regularized_z": quantized,
185
+ "bottleneck_rep": q_indices,
186
+ }
187
+ return return_dict
188
+
189
+ def get_codebook_entry(self, indices, shape=None):
190
+ indices_shape = indices.shape
191
+ indices_flatten = rearrange(indices, "... -> (...)")
192
+
193
+ emb = self.get_emb()
194
+ z_q = F.embedding(indices_flatten, emb)
195
+ if self.l2_normalized:
196
+ z_q = F.normalize(z_q, p=2, dim=-1)
197
+
198
+ if shape is not None:
199
+ z_q = z_q.reshape(shape)
200
+ else:
201
+ z_q = z_q.reshape([*indices_shape, self.dim])
202
+ return z_q
203
+
204
+ def decode(self, indices):
205
+ return self.get_codebook_entry(indices)
206
+
207
+
208
+ models = {"simvq": SimVectorQuantizer, "bottleneck": Bottleneck}
209
+
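The `models` dict above is the lookup table `models_make` uses. The core trick inside `SimVectorQuantizer.forward` is the straight-through line `quantized = z + (quantized - z).detach()`; the standalone snippet below illustrates just that behaviour with a toy codebook.

```python
import torch

z = torch.randn(4, 8, requires_grad=True)     # continuous token features
emb = torch.randn(16, 8)                      # a small codebook
d = (z**2).sum(1, keepdim=True) + (emb**2).sum(1) - 2 * z @ emb.T
idx = d.argmin(dim=1)                         # nearest code per token
quantized = emb[idx]                          # hard selection (non-differentiable)
st = z + (quantized - z).detach()             # forward value equals the code...
st.sum().backward()                           # ...but the gradient reaches z
print(torch.allclose(st, quantized))          # True
print(z.grad.abs().sum() > 0)                 # tensor(True)
```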
210
+
211
+ class ScalingLayer(nn.Module):
212
+ def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
213
+ super().__init__()
214
+ self.register_buffer("shift", torch.Tensor(mean)[None, :, None, None])
215
+ self.register_buffer("scale", torch.Tensor(std)[None, :, None, None])
216
+
217
+ def forward(self, inp):
218
+ return (inp - self.shift) / self.scale
219
+
220
+ def inv(self, inp):
221
+ return inp * self.scale + self.shift
222
+
223
+
224
+ class TextAlignedTokenizer(nn.Module):
225
+ def __init__(
226
+ self,
227
+ bottleneck,
228
+ bottleneck_token_num=256,
229
+ input_size=384,
230
+ teacher="google/siglip2-so400m-patch14-384",
231
+ input_type="quant",
232
+ pool_scale=1,
233
+ decoder_depth=3,
234
+ select_layer_id=-2,
235
+ *args,
236
+ **kwargs,
237
+ ):
238
+ super().__init__()
239
+ self.input_size = input_size
240
+ self.bottleneck_token_num = bottleneck_token_num
241
+ self.teacher = teacher
242
+ self.input_type = input_type
243
+ self.pool_scale = pool_scale
244
+ self.decoder_depth = decoder_depth
245
+ self.select_layer_id = select_layer_id
246
+
247
+ self.bottleneck_dim = bottleneck["args"]["bottleneck_dim"]
248
+
249
+ self.encoder_config = AutoConfig.from_pretrained(teacher)
250
+ self.encoder = AutoModel.from_config(self.encoder_config).vision_model
251
+
252
+ self.encoder_hidden_dim = self.encoder.config.hidden_size
253
+
254
+ self.decoder_config = Siglip2VisionConfig()
255
+ self.decoder_config.update(
256
+ {
257
+ "patch_size": 1,
258
+ "num_hidden_layers": self.decoder_depth,
259
+ "num_channels": self.bottleneck_dim,
260
+ "hidden_size": self.encoder_hidden_dim,
261
+ }
262
+ )
263
+ self.decoder = Siglip2VisionModel(self.decoder_config)
264
+
265
+ self.encode_task_layer = nn.Sequential(
266
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim), nn.Tanh()
267
+ )
268
+ self.decode_task_layer = nn.Sequential(
269
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
270
+ nn.Tanh(),
271
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
272
+ )
273
+
274
+ bottleneck_args = {
275
+ "token_nums": self.bottleneck_token_num,
276
+ "input_dim": self.encoder_hidden_dim,
277
+ "output_dim": self.bottleneck_dim,
278
+ }
279
+ self.bottleneck = models_make(bottleneck, args=bottleneck_args)
280
+
281
+ self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
282
+ self.image_resize = Resize((self.input_size, self.input_size))
283
+
284
+ def set_vq_eval_deterministic(self, deterministic=True):
285
+ self.bottleneck.regularizer.set_eval_deterministic(deterministic)
286
+
287
+ @property
288
+ def device(self):
289
+ return next(self.parameters()).device
290
+
291
+ @property
292
+ def dtype(self):
293
+ return next(self.parameters()).dtype
294
+
295
+ @classmethod
296
+ def from_checkpoint(cls, ckpt, load_teacher=True, **kwargs):
297
+ ckpt = torch.load(ckpt, map_location="cpu", weights_only=False)
298
+ ckpt_kwargs = ckpt["model"]["args"]
299
+ print(ckpt_kwargs)
300
+ model = cls(**kwargs, **ckpt_kwargs)
301
+ sd = ckpt["model"]["sd"]
302
+ if not load_teacher:
303
+ sd = {k: v for k, v in sd.items() if not k.startswith("teacher")}
304
+ model.load_state_dict(sd, strict=True)
305
+ return model
306
+
307
+ def encode(self, x, **kwargs):
308
+ if x.ndim == 5:
309
+ x = rearrange(x, "b c t h w -> (b t) c h w")
310
+ x = self.scale_layer(x)
311
+ if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
312
+ x = self.image_resize(x)
313
+ vq_feats = self.encoder(x, output_hidden_states=True).hidden_states[
314
+ self.select_layer_id
315
+ ]
316
+
317
+ pool_scale = self.pool_scale
318
+ pool_scale = kwargs.get("pool_scale", pool_scale)
319
+ if pool_scale != 1:
320
+ vq_feats = self.avg_pool(vq_feats, pool_scale)
321
+ vq_feats = self.encode_task_layer(vq_feats.to(x))
322
+
323
+ bottleneck_out = self.bottleneck(vq_feats)
324
+ z = bottleneck_out.pop("output")
325
+
326
+ return {
327
+ "encoded": z,
328
+ "pool_scale": pool_scale,
329
+ "vq_feats": vq_feats,
330
+ **bottleneck_out,
331
+ }
332
+
333
+ def avg_pool(self, z, pool_scale=1):
334
+ if z.ndim == 3:
335
+ b, n, c = z.shape
336
+ p = int(n**0.5)
337
+ z = rearrange(z, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
338
+ else:
339
+ b, c, p, _ = z.shape
340
+ p_s = int(p // pool_scale)
341
+ z = F.avg_pool2d(
342
+ z, kernel_size=(pool_scale, pool_scale), stride=(pool_scale, pool_scale)
343
+ ).contiguous()
344
+ z = rearrange(z, "b c p1 p2 -> b (p1 p2) c")
345
+ return z
346
+
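Standalone arithmetic mirroring `avg_pool` above: a square token grid shrinks by `pool_scale**2`. The token count and channel width are illustrative.

```python
import math
import torch

z = torch.randn(1, 729, 16)                             # 27 x 27 = 729 tokens
p = math.isqrt(z.shape[1])
grid = z.transpose(1, 2).reshape(1, 16, p, p)           # (b, c, 27, 27)
pooled = torch.nn.functional.avg_pool2d(grid, 3, 3)     # pool_scale = 3
print(pooled.flatten(2).transpose(1, 2).shape)          # torch.Size([1, 81, 16])
```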
347
+ def decode(self, z):
348
+ if z.ndim == 4:
349
+ z = rearrange(z, "b c p1 p2 -> b (p1 p2) c")
350
+ attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
351
+ p = int(z.shape[1] ** 0.5)
352
+ spatial_shape = torch.tensor([[p, p]] * z.shape[0], device=self.device)
353
+ z = self.decoder(
354
+ z, attention_mask, spatial_shape, output_hidden_states=True
355
+ ).last_hidden_state
356
+ z = self.decode_task_layer(z)
357
+ return z
358
+
359
+ def decode_from_bottleneck(self, bottleneck_rep):
360
+ z = self.bottleneck.decode(bottleneck_rep)
361
+ p = int(z.shape[1] ** 0.5)
362
+ z = rearrange(z, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
363
+ return self.decode(z)
364
+
365
+ def forward(self, data, **kwargs):
366
+ encode_output = self.encode(data, **kwargs)
367
+ vq_feats = encode_output["encoded"]
368
+ p = int(vq_feats.shape[1] ** 0.5)
369
+ vq_feats = rearrange(vq_feats, "b (h w) c -> b c h w", h=p, w=p)
370
+ pred_feats = self.decode(vq_feats)
371
+
372
+ if self.input_type == "quant":
373
+ z = encode_output["regularized_z"]
374
+ elif self.input_type == "indices":
375
+ z = encode_output["bottleneck_rep"]
376
+ elif self.input_type == "rec":
377
+ z = pred_feats
378
+ encode_output["encoded"] = z
379
+ return encode_output
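A hedged end-to-end sketch of `TextAlignedTokenizer`: the checkpoint path is a placeholder, `from_checkpoint` takes its constructor kwargs from that file, and building the encoder pulls the SigLIP2 teacher config from the Hub, so this is only meant to show the intended flow from pixels to discrete tokens.

```python
import torch

tok = TextAlignedTokenizer.from_checkpoint("path/to/ta_tok.pth", load_teacher=False)
tok.eval()
tok.set_vq_eval_deterministic(True)                 # argmax codes instead of sampling
with torch.no_grad():
    out = tok.encode(torch.rand(1, 3, 384, 384))    # pixels in [0, 1]
print(out["bottleneck_rep"].shape)                  # discrete token ids, roughly (1, n_tokens)
```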
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:666f303c324b9b2e2e8f13950cd44a18896a6fc1a70aae70583a77663d0ebe31
3
+ size 23621510
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d7bac0a38c3af8f44d4b3b23d536111c1493fea74d3e7e2a71d804f63dada55
3
+ size 13220225
video_preprocessor_config.json ADDED
@@ -0,0 +1,89 @@
1
+ {
2
+ "_valid_kwargs_names": [
3
+ "do_convert_rgb",
4
+ "do_resize",
5
+ "size",
6
+ "size_divisor",
7
+ "default_to_square",
8
+ "resample",
9
+ "do_rescale",
10
+ "rescale_factor",
11
+ "do_normalize",
12
+ "image_mean",
13
+ "image_std",
14
+ "do_pad",
15
+ "do_center_crop",
16
+ "crop_size",
17
+ "data_format",
18
+ "input_data_format",
19
+ "device",
20
+ "min_pixels",
21
+ "max_pixels",
22
+ "patch_size",
23
+ "temporal_patch_size",
24
+ "merge_size"
25
+ ],
26
+ "auto_map": {
27
+ "AutoProcessor": "processing_vlm.HCXVisionV2Processor"
28
+ },
29
+ "crop_size": null,
30
+ "data_format": "channels_first",
31
+ "default_to_square": true,
32
+ "device": null,
33
+ "do_center_crop": null,
34
+ "do_convert_rgb": true,
35
+ "do_normalize": true,
36
+ "do_pad": null,
37
+ "do_rescale": true,
38
+ "do_resize": true,
39
+ "image_mean": [
40
+ 0.48145466,
41
+ 0.4578275,
42
+ 0.40821073
43
+ ],
44
+ "image_processor_type": "Qwen2VLImageProcessor",
45
+ "image_std": [
46
+ 0.26862954,
47
+ 0.26130258,
48
+ 0.27577711
49
+ ],
50
+ "input_data_format": null,
51
+ "max_pixels": 12845056,
52
+ "merge_size": 2,
53
+ "min_pixels": 3136,
54
+ "model_valid_processing_keys": [
55
+ "do_convert_rgb",
56
+ "do_resize",
57
+ "size",
58
+ "size_divisor",
59
+ "default_to_square",
60
+ "resample",
61
+ "do_rescale",
62
+ "rescale_factor",
63
+ "do_normalize",
64
+ "image_mean",
65
+ "image_std",
66
+ "do_pad",
67
+ "do_center_crop",
68
+ "crop_size",
69
+ "data_format",
70
+ "input_data_format",
71
+ "device",
72
+ "min_pixels",
73
+ "max_pixels",
74
+ "patch_size",
75
+ "temporal_patch_size",
76
+ "merge_size"
77
+ ],
78
+ "patch_size": 14,
79
+ "processor_class": "HCXVisionV2Processor",
80
+ "resample": 3,
81
+ "rescale_factor": 0.00392156862745098,
82
+ "size": {
83
+ "longest_edge": 12845056,
84
+ "shortest_edge": 3136
85
+ },
86
+ "size_divisor": null,
87
+ "temporal_patch_size": 2,
88
+ "video_processor_type": "Qwen2VLVideoProcessor"
89
+ }
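Illustrative arithmetic connecting this config to the video-token expansion in processing_vlm.py: frames are grouped in pairs (`temporal_patch_size`), split into 14x14 patches, and merged 2x2 into LLM tokens, with `min_pixels`/`max_pixels` bounding the resized frame area. The frame count and resolution below are made up.

```python
patch, merge, temporal = 14, 2, 2                # patch_size, merge_size, temporal_patch_size
frames, h, w = 16, 448, 728                      # an illustrative resized clip
grid_t, grid_h, grid_w = frames // temporal, h // patch, w // patch
num_video_tokens = (grid_t * grid_h * grid_w) // merge**2
print(grid_t, grid_h, grid_w, num_video_tokens)  # 8 32 52 3328
```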