Commit
·
3169f6c
0
Parent(s):
Init
Browse files- .gitattributes +8 -0
- LICENSE +122 -0
- README.md +459 -0
- chat_template.jinja +173 -0
- config.json +320 -0
- configuration_hyperclovax.py +228 -0
- configuration_vlm.py +169 -0
- cosyvoice.py +516 -0
- decoder/audio/NCCosybigvganDecoder.mar +3 -0
- decoder/audio/NCZSCosybigvganDecoder.mar +3 -0
- decoder/vision/model_index.json +25 -0
- decoder/vision/scheduler/scheduler_config.json +18 -0
- decoder/vision/token_embedder/config.json +7 -0
- decoder/vision/token_embedder/diffusion_pytorch_model.safetensors +3 -0
- decoder/vision/transformer/config.json +21 -0
- decoder/vision/transformer/diffusion_pytorch_model.safetensors +3 -0
- decoder/vision/transformer2/config.json +21 -0
- decoder/vision/transformer2/diffusion_pytorch_model.safetensors +3 -0
- decoder/vision/vae/config.json +38 -0
- decoder/vision/vae/diffusion_pytorch_model.safetensors +3 -0
- generation_config.json +6 -0
- mambamia_videoaudio_compressor.py +803 -0
- model-00001-of-00010.safetensors +3 -0
- model-00002-of-00010.safetensors +3 -0
- model-00003-of-00010.safetensors +3 -0
- model-00004-of-00010.safetensors +3 -0
- model-00005-of-00010.safetensors +3 -0
- model-00006-of-00010.safetensors +3 -0
- model-00007-of-00010.safetensors +3 -0
- model-00008-of-00010.safetensors +3 -0
- model-00009-of-00010.safetensors +3 -0
- model-00010-of-00010.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_hyperclovax.py +1866 -0
- modeling_vlm.py +0 -0
- patch_vuvlm.py +1085 -0
- preprocessor.py +0 -0
- preprocessor_config.json +32 -0
- processing_vlm.py +963 -0
- processor_config.json +6 -0
- special_tokens_map.json +30 -0
- ta_tok.py +379 -0
- tokenizer.json +3 -0
- tokenizer_config.json +3 -0
- video_preprocessor_config.json +89 -0
.gitattributes
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
decoder/audio/NCCosybigvganDecoder.mar filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
decoder/audio/NCZSCosybigvganDecoder.mar filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HyperCLOVA X SEED 8B Omni Model License Agreement
|
| 2 |
+
|
| 3 |
+
Model Release Date: December 29, 2025
|
| 4 |
+
|
| 5 |
+
This HyperCLOVA X SEED 8B Omni Model License Agreement (the “Agreement”) is a legal agreement between you and NAVER Corporation (“Naver Corp.”) and NAVER Cloud Corporation (“Naver Cloud Corp.”) (Naver Corp. and Naver Cloud Corp. are collectively referred to as “NAVER”) and governs your use of the Models that NAVER provides to You under this Agreement.
|
| 6 |
+
|
| 7 |
+
NAVER Corp., as the holder of the intellectual property of the Model, and its affiliate, NAVER Cloud Corp., as the exclusive business operator of HyperCLOVA X, enter into this Agreement with you. NAVER and you are each a “party” and collectively the “parties.”
|
| 8 |
+
|
| 9 |
+
By using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model or Derivative Model, or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement. You represent to us that you are lawfully able to enter into contracts, and if you are entering into this Agreement for an entity, that you have legal authority to bind that entity.
|
| 10 |
+
|
| 11 |
+
1. Definitions.
|
| 12 |
+
|
| 13 |
+
1.1. "Affiliate” means any entity directly or indirectly controlling, controlled by or under common control with either party, where “control” means the possession, directly or indirectly, of the power to independently direct or cause the direction of the management and policies of an entity, whether through ownership of more than fifty percent (50%) of the stock or other equity interests entitled to vote for representation on its board of directors, or body performing similar functions, by contract or otherwise.
|
| 14 |
+
|
| 15 |
+
1.2. “Derivative Model” means all (i) modifications to the Model, (ii) works based on the Model, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of the Model, to that model in order to cause that model to perform similarly to the Model, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by the Model for training that Model. For clarity, Outputs are not deemed Derivative Model.
|
| 16 |
+
|
| 17 |
+
1.3. “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 18 |
+
|
| 19 |
+
1.4. “Model” means the foundational large language models and software and algorithms, including machine-learning model code and trained model weights distributed by NAVER.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
1.5. “Output” means the information content output of the Model or a Derivative Model that results from operating or otherwise using the Model or Derivative Model.
|
| 23 |
+
|
| 24 |
+
2. Conditions for Use, License Grant and Restrictions
|
| 25 |
+
|
| 26 |
+
2.1. Conditions for Use. The Model and any Derivative Model are subject to the terms of this Agreement and govern your use. If You institute copyright or patent litigation against any entity (including a crossclaim or counterclaim in a lawsuit) alleging that the Model or Derivative Model constitutes direct or contributory copyright or patent infringement, then any license granted to you under this Agreement for that Model or Derivative Model will terminate as of the date such litigation is filed. NAVER may update this Agreement to comply with legal and regulatory requirements any time and You agree to either comply with any updated license or cease your copying, use, and distribution of the Model and any Derivative Model.
|
| 27 |
+
|
| 28 |
+
2.2. License Grant. Subject to the terms and conditions of this Agreement, NAVER hereby grants to you a non-exclusive, worldwide, non-transferable, revocable and royalty-free limited license under NAVER’s intellectual property or other rights owned by NAVER embodied in the Model to access, download, install, copy, use, reproduce, distribute, create derivative works of, and make modifications to the Model.
|
| 29 |
+
|
| 30 |
+
2.3. Prohibited Use Policy. NAVER is committed to ensuring safety trust, and transparency in the development and use of AI technologies. Accordingly, your use of the Model and any Derivative Models is subject to the following conditions:
|
| 31 |
+
(i) You must ensure that any product or service you develop, use, offer as a service, or distribute complies with all applicable laws and regulations, and is operated appropriately for the relevant industry or use case.
|
| 32 |
+
(ii) You must comply with the Acceptable Use Policy applicable to the Model and any Derivative Models, which is attached hereto as Addendum A and incorporated by reference into this Agreement.
|
| 33 |
+
(iii) NAVER expressly prohibits the use of its products or services for any purpose in violation of applicable law and regulation, including but not limited to:
|
| 34 |
+
(a) illegal surveillance,
|
| 35 |
+
(b) illegal collection or processing of biometric information without the consent of the subject which is required under applicable law, or
|
| 36 |
+
(c) illegal harassment, abuse, threatening or bullying of individuals or groups of individuals or intentionally misleading or deceiving others.
|
| 37 |
+
(iv) You must take reasonable measures to address unintended bias and to mitigate harm to others, including underrepresented or vulnerable groups.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
3. Redistribution.
|
| 41 |
+
|
| 42 |
+
3.1. You may reproduce, distribute or make available the Model or Derivative Models thereof, or a product or service (including another AI model) that contains any of them, if you meet all of the following conditions: you must (i) include the Prohibited Use Policy referenced in Section 2.3. as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of the Model or Derivative Model and you must provide notice to subsequence users you distribute to the Model or Derivative Models are subject to the use restrictions in Section 2.3., (ii) provide all third party recipients of the Model or Derivative Models a copy of this Agreement, (iii) cause any modified files to carry prominent notices stating that you modified the files; (iv) include the following attribution notice within a “Notice” text file distributed as part of such copies: “HyperCLOVA X SEED 8B Omni Model is licensed under the HyperCLOVA X SEED 8B Omni Model License Agreement, Copyright © NAVER Corp. All Rights Reserved.”, and (v) prominently display “Powered by HyperCLOVA X” on a related website, user interface, blogpost, about page, or product documentation. If you use the Model or any Outputs of the Model to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include “HyperCLOVA X” at the beginning of any such AI model name.
|
| 43 |
+
3.2. You may add your own copyright statement to your modifications and, except as set forth in this Section, may provide additional or different license terms and conditions for use, reproduction, or distribution of your modifications, or for any such Derivative Models as a whole, provided your use, reproduction, and distribution of the Model or Derivative Models otherwise comply with the terms and conditions stated in this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement.
|
| 44 |
+
|
| 45 |
+
4. Additional Commercial Terms. If (i) as of the Model Release Date, the monthly active users of the products or services made available by or for Licensee, or Licensee’s Affiliates, is greater than 10 million monthly active users in the preceding calendar month, or (ii) the Licensee or its Affiliate distributes or makes available any product or service, which is substantially similar to or directly competes with any product and service provided by NAVER, then the Licensee must request a license from NAVER. Such a license may be granted by NAVER at its sole discretion, and the Licensee is not authorized to exercise any rights under this Agreement unless and until NAVER expressly grants you such rights.
|
| 46 |
+
|
| 47 |
+
5. Generated Output. NAVER claims no rights in Outputs you generate using the Model. You and your use are solely responsible for Outputs and their subsequent uses.
|
| 48 |
+
|
| 49 |
+
6. DISCLAIMER OF WARRANTY. UNLESS REQUIRED BY APPLICABLE LAW, THE MODEL AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OR ANY KIND, AND NAVER DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE MODEL, DERIVATIVE MODELS, OUTPUTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE MODEL AND ANY OUTPUTS AND RESULTS AND YOUR EXERCISE OF PERMISSION UNDER THIS AGREEMENT.
|
| 50 |
+
|
| 51 |
+
7. LIMITATION OF LIABILITY. IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW (SUCH AS IN CASES OF DELIBERATE AND GROSSLY NEGLIGENT ACTS), WILL NAVER BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND, ARISING FROM OR RELATED TO THIS AGREEMENT, OR RESULTING FROM THE USE OR INABILITY TO USE THE MODEL, DERIVATIVE MODELS OR, OUTPUTS (INCLUDING, BUT NOT LIMITED TO, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGES, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF NAVER HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 52 |
+
|
| 53 |
+
8. Indemnity. You will indemnify and hold harmless NAVER from and against any claim by any third party arising out of or related to your use or distribution of the Model, Derivative Model or Outputs.
|
| 54 |
+
|
| 55 |
+
9. Intellectual Property.
|
| 56 |
+
|
| 57 |
+
9.1. This Agreement does not grant permission to use the trade names, trademarks, service marks, or product names of NAVER, except as required for reasonable and customary use in describing the origin of the Model and reproducing the content of the “Notice” text file.
|
| 58 |
+
|
| 59 |
+
9.2. NAVER Corp. owns the Model and any Derivative Model created by NAVER Corp. Except as expressively granted in this Agreement, NAVER Corp. reserves all rights, interests and remedies in connection with the Model and Derivative Model created by NAVER Corp. and no other license or right is granted to you by implication, estoppel or otherwise. Subject to NAVER Corp.’s ownership of the Model and any Derivative Model made by or for NAVER Corp., with respect to any derivative works and modifications of the Model that are made by you, as between you and NAVER Corp., you are and will be the owner of such derivative works and modifications.
|
| 60 |
+
|
| 61 |
+
10. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Model and will continue in full force and effect until terminated in accordance with the terms and conditions of this Agreement. NAVER may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Model and Derivative Model. Section 5, 6, 7 and 10 shall survive the termination of this Agreement.
|
| 62 |
+
|
| 63 |
+
11. Governing Law and Jurisdiction.
|
| 64 |
+
|
| 65 |
+
11.1. This Agreement will be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflicts of laws principles.
|
| 66 |
+
|
| 67 |
+
11.2. Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English. Either party may seek interim or provisional relief from a court of competent jurisdiction and doing so shall not be considered a waiver of any provision in this section. The arbitral tribunal also has the authority to issue orders for interim or provisional relief.
|
| 68 |
+
|
| 69 |
+
12. Modifications. NAVER reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on our website or through other means of communication. You are responsible for reviewing the Agreement periodically for changes.
|
| 70 |
+
|
| 71 |
+
13. No Waiver. NAVER will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement.
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
Addendum A – Acceptable Use Policy
|
| 76 |
+
|
| 77 |
+
NAVER is committed to promoting safe and responsible use of its AI technologies, including the HyperCLOVA X SEED 8B Omni Model (the “Model”). By accessing or using the Model and Derivative Model (Defined in the Model License Agreement) (the Model and Derivative Model are collectively referred to as the “Models”), you agree to this Acceptable Use Policy (“Policy”).
|
| 78 |
+
|
| 79 |
+
We want everyone to use the Models safely, legally, and ethically. You agree that you will not use, or allow others to use, the Models to:
|
| 80 |
+
|
| 81 |
+
1. Violate applicable laws or the rights of others, including by:
|
| 82 |
+
a. Engaging in, promoting, contributing to, encouraging, planning, inciting, or furthering illegal or unlawful activity or content, such as:
|
| 83 |
+
Violence or terrorism
|
| 84 |
+
Exploitation or harm to children, including the creation or dissemination of child exploitative content
|
| 85 |
+
Human trafficking, exploitation, or sexual violence
|
| 86 |
+
The unlawful distribution of obscene or harmful material to minors, or failure to apply legally required age restrictions
|
| 87 |
+
Sexual solicitation or sexually exploitative behavior
|
| 88 |
+
Any other criminal activity
|
| 89 |
+
b. Engaging in, promoting, inciting, or facilitating the harassment, abuse, threatening, or bullying of individuals or groups
|
| 90 |
+
c. Engaging in, promoting, inciting, or facilitating discrimination or other unlawful or harmful conduct in the provision of employment, credit, housing, or access to essential goods and services
|
| 91 |
+
d. Providing unauthorized or unlicensed professional services, including but not limited to financial, legal, medical/health, or related services
|
| 92 |
+
e. Collecting, processing, disclosing, generating, or inferring private or sensitive personal information, including identity, health, or demographic data, unless lawfully permitted under applicable laws
|
| 93 |
+
f. Infringing, misappropriating, or otherwise violating third-party rights, including through the generation or use of outputs derived from the Models
|
| 94 |
+
g. Creating, generating, or facilitating malicious code, malware, or computer viruses, or interfering with the functioning, security, or integrity of a website, application, or system
|
| 95 |
+
h. Intentionally bypassing or disabling usage restrictions, safety measures, or access controls imposed by NAVER
|
| 96 |
+
|
| 97 |
+
2. Engage in or promote use cases that may pose a risk of death, bodily harm, or significant safety hazard to individuals, including use of the Models in connection with:
|
| 98 |
+
a. Military, warfare, nuclear technology or espionage
|
| 99 |
+
b. The development or distribution of firearms or illegal weapons
|
| 100 |
+
c. Illegal drugs or regulated controlled substances
|
| 101 |
+
d. Operation of critical infrastructure, transportation systems, or heavy machinery
|
| 102 |
+
e. Content promoting self-harm, including suicide, or eating disorders
|
| 103 |
+
f. Any other use intended to incite or cause physical harm
|
| 104 |
+
|
| 105 |
+
3. Intentionally deceive or mislead others, including by:
|
| 106 |
+
a. Generating, promoting, or disseminating fraudulent or misleading content
|
| 107 |
+
b. Creating or sharing defamatory content
|
| 108 |
+
c. Generating or distributing spam
|
| 109 |
+
d. Impersonating another individual or entity without proper authorization
|
| 110 |
+
e. Representing Model output as human-generated
|
| 111 |
+
f. Generating or enabling fake online engagement, such as fake reviews or fake users
|
| 112 |
+
|
| 113 |
+
4. Fail to disclose to end users any known risks or limitations of an AI system that incorporates the Models.
|
| 114 |
+
|
| 115 |
+
5. Use the Models in conjunction with third-party tools, models, or software designed to generate unlawful content or conduct, or falsely represent outputs from such tools as associated with NAVER or HyperCLOVA X.
|
| 116 |
+
|
| 117 |
+
If you become aware of a violation of this Policy, a bug, or any behavior that could result in a breach of this Policy, please report it to us:
|
| 118 |
+
|
| 119 |
+
Reporting risky outputs: [email protected]
|
| 120 |
+
Reporting policy violations or unauthorized use: [email protected]
|
| 121 |
+
|
| 122 |
+
|
README.md
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: other
|
| 3 |
+
license_name: hyperclovax
|
| 4 |
+
license_link: LICENSE
|
| 5 |
+
library_name: transformers
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+

|
| 9 |
+
|
| 10 |
+
# Overview
|
| 11 |
+
HyperCLOVA X SEED 8B Omni is a unified multimodal model that brings text, vision, and speech together, based on an auto-regressive Transformer architecture, enabling consistent multimodal understanding and generation. SEED 8B Omni aligns textual, visual, and audio representations in a shared semantic space and supports bidirectional interactions across modalities, including established text capabilities as well as vision–language QA, text-to-image generation and editing, speech recognition and translation, and text-to-speech, within a 32K context window. As an early pathfinding milestone of HyperCLOVA X toward **Any-to-Any-Korean-First** intelligence, SEED 8B Omni serves as a practical exploration of unified multimodal modeling and provides a reference point for future development and scaling.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Basic Information
|
| 16 |
+
|
| 17 |
+
- **Architecture** : Transformer-based omni-model architecture (Dense Model)
|
| 18 |
+
- **Parameters** : 8B
|
| 19 |
+
- **Input Format**: Text/Image/Video/Audio(Speech)
|
| 20 |
+
- **Output Format**: Text/Image/Audio(Speech)
|
| 21 |
+
- **Context Length** : 32K
|
| 22 |
+
- **Knowledge Cutoff**: May 2025
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
# Benchmarks
|
| 27 |
+

|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
- **Text-to-Text** : MMLU-Pro, GSM8K, KMMLU-Pro, HAERAE 1.0
|
| 31 |
+
- **Vision-to-Text** :SEED-IMG, AI2D, K-MMBench
|
| 32 |
+
- **Text-to-Vision**: GenEval, ImgEdit
|
| 33 |
+
- **Audio-to-Text**: Librispeech, Ksponspeech
|
| 34 |
+
- **Audio-to-Audio**:Fleurs en2ko, Fleurs ko2en
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
# Examples
|
| 39 |
+
## Text-to-Image Generation
|
| 40 |
+

|
| 41 |
+
## Text-based Image Editing
|
| 42 |
+

|
| 43 |
+

|
| 44 |
+

|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
# Inference
|
| 49 |
+
|
| 50 |
+
We provide [OmniServe](https://github.com/NAVER-Cloud-HyperCLOVA-X/OmniServe), a production-ready multimodal inference system with OpenAI-compatible API.
|
| 51 |
+
|
| 52 |
+
## Capabilities
|
| 53 |
+
|
| 54 |
+
- **Inputs**: Text, Image, Audio, Video
|
| 55 |
+
- **Outputs**: Text, Image, Audio (no video generation)
|
| 56 |
+
|
| 57 |
+
## Requirements
|
| 58 |
+
|
| 59 |
+
- 4x NVIDIA A100 80GB
|
| 60 |
+
- Docker & Docker Compose
|
| 61 |
+
- NVIDIA Driver 525+, CUDA 12.1+
|
| 62 |
+
- S3-compatible storage (for image/audio output)
|
| 63 |
+
|
| 64 |
+
## Installation
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
# Clone OmniServe
|
| 68 |
+
git clone https://github.com/NAVER-Cloud-HyperCLOVA-X/OmniServe.git
|
| 69 |
+
cd OmniServe
|
| 70 |
+
|
| 71 |
+
# Install dependencies
|
| 72 |
+
pip install huggingface_hub safetensors torch openai easydict
|
| 73 |
+
|
| 74 |
+
# Download model (~16GB)
|
| 75 |
+
huggingface-cli download naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B \
|
| 76 |
+
--local-dir ./models/HyperCLOVAX-SEED-Omni-8B
|
| 77 |
+
|
| 78 |
+
# Convert model to component format
|
| 79 |
+
python convert_model.py \
|
| 80 |
+
--input ./models/HyperCLOVAX-SEED-Omni-8B \
|
| 81 |
+
--output ./track_b \
|
| 82 |
+
--track b
|
| 83 |
+
|
| 84 |
+
# Configure environment
|
| 85 |
+
cp .env.example .env
|
| 86 |
+
# Edit .env with model paths and S3 credentials
|
| 87 |
+
|
| 88 |
+
# Build and run (Track B only - OMNI model)
|
| 89 |
+
docker compose --profile track-b build
|
| 90 |
+
docker compose --profile track-b up -d
|
| 91 |
+
|
| 92 |
+
# Wait for model loading (~5 minutes)
|
| 93 |
+
docker compose logs -f omni
|
| 94 |
+
|
| 95 |
+
# Note: To run both VLM and OMNI models together:
|
| 96 |
+
# docker compose --profile track-a --profile track-b up -d
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
## Basic Usage
|
| 100 |
+
|
| 101 |
+
```python
|
| 102 |
+
from openai import OpenAI
|
| 103 |
+
|
| 104 |
+
client = OpenAI(
|
| 105 |
+
base_url="http://localhost:8000/b/v1",
|
| 106 |
+
api_key="not-needed"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Image understanding
|
| 110 |
+
response = client.chat.completions.create(
|
| 111 |
+
model="track_b_model",
|
| 112 |
+
messages=[
|
| 113 |
+
{
|
| 114 |
+
"role": "user",
|
| 115 |
+
"content": [
|
| 116 |
+
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
| 117 |
+
{"type": "text", "text": "What is in this image?"}
|
| 118 |
+
]
|
| 119 |
+
}
|
| 120 |
+
],
|
| 121 |
+
max_tokens=256,
|
| 122 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
print(response.choices[0].message.content)
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## More Examples
|
| 129 |
+
|
| 130 |
+
<details>
|
| 131 |
+
<summary>Text to Image</summary>
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
import json
|
| 135 |
+
|
| 136 |
+
SYSTEM_PROMPT = """You are an AI assistant that generates images. When asked to draw or create an image, you MUST use the t2i_model_generation tool to generate the image. Always respond by calling the tool."""
|
| 137 |
+
|
| 138 |
+
response = client.chat.completions.create(
|
| 139 |
+
model="track_b_model",
|
| 140 |
+
messages=[
|
| 141 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 142 |
+
{"role": "user", "content": "Draw a sunset over mountains"}
|
| 143 |
+
],
|
| 144 |
+
tools=[{
|
| 145 |
+
"type": "function",
|
| 146 |
+
"function": {
|
| 147 |
+
"name": "t2i_model_generation",
|
| 148 |
+
"description": "Generates an RGB image based on the provided discrete image representation.",
|
| 149 |
+
"parameters": {
|
| 150 |
+
"type": "object",
|
| 151 |
+
"required": ["discrete_image_token"],
|
| 152 |
+
"properties": {
|
| 153 |
+
"discrete_image_token": {
|
| 154 |
+
"type": "string",
|
| 155 |
+
"description": "A serialized string of discrete vision tokens, encapsulated by special tokens. The format must be strictly followed: <|discrete_image_start|><|vision_ratio_4:3|><|vision_token|><|visionaaaaa|><|visionbbbbb|>... <|visionzzzzz|><|vision_eol|><|vision_eof|><|discrete_image_end|>.",
|
| 156 |
+
"minLength": 1
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
}],
|
| 162 |
+
max_tokens=7000,
|
| 163 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
if response.choices[0].message.tool_calls:
|
| 167 |
+
args = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
| 168 |
+
print(f"Generated image: {args['discrete_image_token']}")
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
</details>
|
| 172 |
+
|
| 173 |
+
<details>
|
| 174 |
+
<summary>Text to Audio</summary>
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
import base64
|
| 178 |
+
|
| 179 |
+
# Prompt should explicitly request speech/audio output
|
| 180 |
+
response = client.chat.completions.create(
|
| 181 |
+
model="track_b_model",
|
| 182 |
+
messages=[{
|
| 183 |
+
"role": "user",
|
| 184 |
+
"content": "Read this text aloud in a cheerful female voice:\nHello! How are you today?"
|
| 185 |
+
}],
|
| 186 |
+
max_tokens=1000,
|
| 187 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
if response.choices[0].message.audio:
|
| 191 |
+
audio_url = base64.b64decode(response.choices[0].message.audio.data).decode()
|
| 192 |
+
print(f"Generated audio: {audio_url}")
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
</details>
|
| 196 |
+
|
| 197 |
+
<details>
|
| 198 |
+
<summary>Audio Input</summary>
|
| 199 |
+
|
| 200 |
+
```python
|
| 201 |
+
import base64
|
| 202 |
+
|
| 203 |
+
audio_url = "https://example.com/audio.mp3"
|
| 204 |
+
audio_data = base64.b64encode(audio_url.encode()).decode()
|
| 205 |
+
|
| 206 |
+
response = client.chat.completions.create(
|
| 207 |
+
model="track_b_model",
|
| 208 |
+
messages=[
|
| 209 |
+
{
|
| 210 |
+
"role": "user",
|
| 211 |
+
"content": [
|
| 212 |
+
{"type": "input_audio", "input_audio": {"data": audio_data, "format": "mp3"}},
|
| 213 |
+
{"type": "text", "text": "What is being said?"}
|
| 214 |
+
]
|
| 215 |
+
}
|
| 216 |
+
],
|
| 217 |
+
max_tokens=256,
|
| 218 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
print(response.choices[0].message.content)
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
</details>
|
| 225 |
+
|
| 226 |
+
<details>
|
| 227 |
+
<summary>Video Input</summary>
|
| 228 |
+
|
| 229 |
+
```python
|
| 230 |
+
response = client.chat.completions.create(
|
| 231 |
+
model="track_b_model",
|
| 232 |
+
messages=[
|
| 233 |
+
{
|
| 234 |
+
"role": "user",
|
| 235 |
+
"content": [
|
| 236 |
+
{"type": "image_url", "image_url": {"url": "https://example.com/video.mp4"}},
|
| 237 |
+
{"type": "text", "text": "Describe this video."}
|
| 238 |
+
]
|
| 239 |
+
}
|
| 240 |
+
],
|
| 241 |
+
max_tokens=512,
|
| 242 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
print(response.choices[0].message.content)
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
</details>
|
| 249 |
+
|
| 250 |
+
<details>
|
| 251 |
+
<summary>Image to Image</summary>
|
| 252 |
+
|
| 253 |
+
```python
|
| 254 |
+
import json
|
| 255 |
+
|
| 256 |
+
SYSTEM_PROMPT = """You are an AI assistant that transforms images. When asked to transform, edit, or stylize an image, you MUST use the t2i_model_generation tool to generate the new image. Always respond by calling the tool."""
|
| 257 |
+
|
| 258 |
+
response = client.chat.completions.create(
|
| 259 |
+
model="track_b_model",
|
| 260 |
+
messages=[
|
| 261 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 262 |
+
{
|
| 263 |
+
"role": "user",
|
| 264 |
+
"content": [
|
| 265 |
+
{"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
|
| 266 |
+
{"type": "text", "text": "Transform to watercolor style"}
|
| 267 |
+
]
|
| 268 |
+
}
|
| 269 |
+
],
|
| 270 |
+
tools=[{
|
| 271 |
+
"type": "function",
|
| 272 |
+
"function": {
|
| 273 |
+
"name": "t2i_model_generation",
|
| 274 |
+
"description": "Generates an RGB image based on the provided discrete image representation.",
|
| 275 |
+
"parameters": {
|
| 276 |
+
"type": "object",
|
| 277 |
+
"required": ["discrete_image_token"],
|
| 278 |
+
"properties": {
|
| 279 |
+
"discrete_image_token": {
|
| 280 |
+
"type": "string",
|
| 281 |
+
"description": "A serialized string of discrete vision tokens, encapsulated by special tokens. The format must be strictly followed: <|discrete_image_start|><|vision_ratio_4:3|><|vision_token|><|visionaaaaa|><|visionbbbbb|>... <|visionzzzzz|><|vision_eol|><|vision_eof|><|discrete_image_end|>.",
|
| 282 |
+
"minLength": 1
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
}],
|
| 288 |
+
max_tokens=7000,
|
| 289 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
if response.choices[0].message.tool_calls:
|
| 293 |
+
args = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
| 294 |
+
print(f"Generated image: {args['discrete_image_token']}")
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
</details>
|
| 298 |
+
|
| 299 |
+
<details>
|
| 300 |
+
<summary>Audio to Audio</summary>
|
| 301 |
+
|
| 302 |
+
```python
|
| 303 |
+
import base64
|
| 304 |
+
|
| 305 |
+
# Input audio (URL encoded as base64)
|
| 306 |
+
audio_url = "https://example.com/input.mp3"
|
| 307 |
+
audio_data = base64.b64encode(audio_url.encode()).decode()
|
| 308 |
+
|
| 309 |
+
response = client.chat.completions.create(
|
| 310 |
+
model="track_b_model",
|
| 311 |
+
messages=[
|
| 312 |
+
{
|
| 313 |
+
"role": "user",
|
| 314 |
+
"content": [
|
| 315 |
+
{"type": "input_audio", "input_audio": {"data": audio_data, "format": "mp3"}},
|
| 316 |
+
{"type": "text", "text": "Listen to this and respond with speech"}
|
| 317 |
+
]
|
| 318 |
+
}
|
| 319 |
+
],
|
| 320 |
+
max_tokens=2000,
|
| 321 |
+
extra_body={"chat_template_kwargs": {"skip_reasoning": True}}
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if response.choices[0].message.audio:
|
| 325 |
+
audio_url = base64.b64decode(response.choices[0].message.audio.data).decode()
|
| 326 |
+
print(f"Generated audio: {audio_url}")
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
</details>
|
| 330 |
+
|
| 331 |
+
<details>
|
| 332 |
+
<summary>Using curl</summary>
|
| 333 |
+
|
| 334 |
+
```bash
|
| 335 |
+
# Image understanding
|
| 336 |
+
curl -X POST http://localhost:8000/b/v1/chat/completions \
|
| 337 |
+
-H "Content-Type: application/json" \
|
| 338 |
+
-d '{
|
| 339 |
+
"model": "track_b_model",
|
| 340 |
+
"messages": [{"role": "user", "content": [
|
| 341 |
+
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
| 342 |
+
{"type": "text", "text": "Describe this image."}
|
| 343 |
+
]}],
|
| 344 |
+
"max_tokens": 256,
|
| 345 |
+
"extra_body": {"chat_template_kwargs": {"skip_reasoning": true}}
|
| 346 |
+
}'
|
| 347 |
+
|
| 348 |
+
# Text to audio
|
| 349 |
+
curl -X POST http://localhost:8000/b/v1/chat/completions \
|
| 350 |
+
-H "Content-Type: application/json" \
|
| 351 |
+
-d '{
|
| 352 |
+
"model": "track_b_model",
|
| 353 |
+
"messages": [{"role": "user", "content": "Say hello"}],
|
| 354 |
+
"max_tokens": 1000,
|
| 355 |
+
"extra_body": {"chat_template_kwargs": {"skip_reasoning": true}}
|
| 356 |
+
}'
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
</details>
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
## Architecture
|
| 363 |
+
|
| 364 |
+
```
|
| 365 |
+
User Request
|
| 366 |
+
(Image/Audio/Video/Text)
|
| 367 |
+
│
|
| 368 |
+
▼
|
| 369 |
+
┌─────────────────────────────────────────────────────────────────────────┐
|
| 370 |
+
│ OmniServe │
|
| 371 |
+
│ POST /b/v1/chat/completions │
|
| 372 |
+
│ │
|
| 373 |
+
│ ┌──────────────────────────────────────────────────────────────────┐ │
|
| 374 |
+
│ │ [1] INPUT ENCODING │ │
|
| 375 |
+
│ │ │ │
|
| 376 |
+
│ │ ┌─────────────────┐ ┌─────────────────┐ │ │
|
| 377 |
+
│ │ │ Vision Encoder │ │ Audio Encoder │ │ │
|
| 378 |
+
│ │ └────────┬────────┘ └────────┬────────┘ │ │
|
| 379 |
+
│ │ │ │ │ │
|
| 380 |
+
│ │ └────────────┬────────────────────┘ │ │
|
| 381 |
+
│ │ │ embeddings │ │
|
| 382 |
+
│ └──────────────────────────┼───────────────────────────────────────┘ │
|
| 383 |
+
│ ▼ │
|
| 384 |
+
│ ┌──────────────┐ │
|
| 385 |
+
│ │ LLM (8B) │◀──── text │
|
| 386 |
+
│ └──────┬───────┘ │
|
| 387 |
+
│ │ │
|
| 388 |
+
│ ┌─────────────────────────┼────────────────────────────────────────┐ │
|
| 389 |
+
│ │ [2] OUTPUT DECODING │ │
|
| 390 |
+
│ │ │ │ │
|
| 391 |
+
│ │ ┌──────────────┼──────────────┐ │ │
|
| 392 |
+
│ │ ▼ ▼ ▼ │ │
|
| 393 |
+
│ │ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ │
|
| 394 |
+
│ │ │ Text │ │ Vision │ │ Audio │ │ │
|
| 395 |
+
│ │ │ │ │ Decoder │ │ Decoder │ │ │
|
| 396 |
+
│ │ └───────────┘ └─────┬─────┘ └─────┬─────┘ │ │
|
| 397 |
+
│ │ │ │ │ │
|
| 398 |
+
│ │ ▼ ▼ │ │
|
| 399 |
+
│ │ Image URL Audio URL │ │
|
| 400 |
+
│ │ (S3) (S3) │ │
|
| 401 |
+
│ └──────────────────────────────────────────────────────────────────┘ │
|
| 402 |
+
│ │
|
| 403 |
+
└─────────────────────────────────────────────────────────────────────────┘
|
| 404 |
+
│
|
| 405 |
+
▼
|
| 406 |
+
Response
|
| 407 |
+
(Text / Image URL / Audio URL)
|
| 408 |
+
```
|
| 409 |
+
|
| 410 |
+
## Hardware Requirements
|
| 411 |
+
|
| 412 |
+
| Component | GPU | VRAM |
|
| 413 |
+
|-----------|-----|------|
|
| 414 |
+
| Vision Encoder | 1x | ~8GB |
|
| 415 |
+
| Audio Encoder | (shared) | ~4GB |
|
| 416 |
+
| LLM (8B) | 1x | ~16GB |
|
| 417 |
+
| Vision Decoder | 1x | ~16GB |
|
| 418 |
+
| Audio Decoder | (shared) | ~4GB |
|
| 419 |
+
| **Total** | **3x** | **~48GB** |
|
| 420 |
+
|
| 421 |
+
## Key Parameters
|
| 422 |
+
|
| 423 |
+
| Parameter | Description | Default |
|
| 424 |
+
|-----------|-------------|---------|
|
| 425 |
+
| `chat_template_kwargs.skip_reasoning` | Skip reasoning | `true` |
|
| 426 |
+
| `max_tokens` | Max output tokens | - |
|
| 427 |
+
| `temperature` | Sampling temperature | 0.7 |
|
| 428 |
+
| `tools` | Required for image generation | - |
|
| 429 |
+
|
| 430 |
+
## S3 Configuration
|
| 431 |
+
|
| 432 |
+
Required for image/audio generation:
|
| 433 |
+
|
| 434 |
+
```bash
|
| 435 |
+
NCP_S3_ENDPOINT=https://your-s3-endpoint.com
|
| 436 |
+
NCP_S3_ACCESS_KEY=your-access-key
|
| 437 |
+
NCP_S3_SECRET_KEY=your-secret-key
|
| 438 |
+
NCP_S3_BUCKET_NAME=your-bucket-name
|
| 439 |
+
```
|
| 440 |
+
|
| 441 |
+
For more details, see [OmniServe documentation](https://github.com/NAVER-Cloud-HyperCLOVA-X/OmniServe).
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
---
|
| 445 |
+
|
| 446 |
+
# Citation
|
| 447 |
+
TBU (Technical Report)
|
| 448 |
+
|
| 449 |
+
---
|
| 450 |
+
|
| 451 |
+
# Questions
|
| 452 |
+
For any other questions, please feel free to contact us at [email protected].
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
---
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# License
|
| 459 |
+
The model is licensed under [HyperCLOVA X SEED 8B Omni Model License Agreement](./LICENSE)
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- set ns_img = namespace(count=0) %}
|
| 2 |
+
{%- set ns_aud = namespace(count=0) %}
|
| 3 |
+
{%- set ns_vid = namespace(count=0) %}
|
| 4 |
+
{%- if tools %}
|
| 5 |
+
{{- '<|im_start|>system\n' }}
|
| 6 |
+
{%- if messages[0].role == 'system' %}
|
| 7 |
+
{%- if messages[0].content is string %}
|
| 8 |
+
{{- messages[0].content + '\n\n' }}
|
| 9 |
+
{%- elif messages[0].content is sequence %}
|
| 10 |
+
{%- for part in messages[0].content %}
|
| 11 |
+
{%- if part.type == 'text' %}
|
| 12 |
+
{{- part.text }}
|
| 13 |
+
{%- endif %}
|
| 14 |
+
{%- endfor %}
|
| 15 |
+
{{- '\n\n' }}
|
| 16 |
+
{%- endif %}
|
| 17 |
+
{%- endif %}
|
| 18 |
+
{{- '# Tools\n\n' }}
|
| 19 |
+
{{- 'You may call one or more functions to assist with the user query.\n\n' }}
|
| 20 |
+
{{- 'You are provided with function signatures within <tools></tools> XML tags:\n' }}
|
| 21 |
+
{{- '<tools>\n' }}
|
| 22 |
+
{%- for tool in tools %}
|
| 23 |
+
{{- tool | tojson }}
|
| 24 |
+
{%- endfor %}
|
| 25 |
+
{{- '\n</tools>\n\n' }}
|
| 26 |
+
{{- 'For each function call, output the function name and arguments within the following XML format:\n' }}
|
| 27 |
+
{{- '<tool_call>{function-name}\n' }}
|
| 28 |
+
{{- '<arg_key>{arg-key-1}</arg_key>\n' }}
|
| 29 |
+
{{- '<arg_value>{arg-value-1}</arg_value>\n' }}
|
| 30 |
+
{{- '<arg_key>{arg-key-2}</arg_key>\n' }}
|
| 31 |
+
{{- '<arg_value>{arg-value-2}</arg_value>\n' }}
|
| 32 |
+
{{- '...\n' }}
|
| 33 |
+
{{- '</tool_call><|im_end|>\n' }}
|
| 34 |
+
{%- else %}
|
| 35 |
+
{%- if messages[0].role == 'system' %}
|
| 36 |
+
{{- '<|im_start|>system\n' }}
|
| 37 |
+
{%- if messages[0].content is string %}
|
| 38 |
+
{{- messages[0].content }}
|
| 39 |
+
{%- elif messages[0].content is sequence %}
|
| 40 |
+
{%- for part in messages[0].content %}
|
| 41 |
+
{%- if part.type == 'text' %}
|
| 42 |
+
{{- part.text }}
|
| 43 |
+
{%- endif %}
|
| 44 |
+
{%- endfor %}
|
| 45 |
+
{%- endif %}
|
| 46 |
+
{{- '<|im_end|>\n' }}
|
| 47 |
+
{%- endif %}
|
| 48 |
+
{%- endif %}
|
| 49 |
+
{%- set ns = namespace(last_user_index=-1) %}
|
| 50 |
+
{%- for m in messages %}
|
| 51 |
+
{%- if m.role == 'user' %}
|
| 52 |
+
{%- set ns.last_user_index = loop.index0 %}
|
| 53 |
+
{%- endif %}
|
| 54 |
+
{%- endfor %}
|
| 55 |
+
{%- for message in messages %}
|
| 56 |
+
{%- set content = message.content %}
|
| 57 |
+
{%- if (message.role == 'system' and not loop.first) %}
|
| 58 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 59 |
+
{%- if content is string %}
|
| 60 |
+
{{- content }}
|
| 61 |
+
{%- elif content is sequence %}
|
| 62 |
+
{%- for part in content %}
|
| 63 |
+
{%- if part.type == 'text' %}
|
| 64 |
+
{{- part.text }}
|
| 65 |
+
{%- endif %}
|
| 66 |
+
{%- endfor %}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{{- '<|im_end|>' + '\n' }}
|
| 69 |
+
{%- elif message.role == 'user' %}
|
| 70 |
+
{{- '<|im_start|>user\n' }}
|
| 71 |
+
{%- if message['content'] is string %}
|
| 72 |
+
{{- message['content'] + '<|im_end|>\n' }}
|
| 73 |
+
{%- elif message['content'] is sequence %}
|
| 74 |
+
{%- for content in message['content'] %}
|
| 75 |
+
{%- if not loop.first %}
|
| 76 |
+
{{- '\n' }}
|
| 77 |
+
{%- endif %}
|
| 78 |
+
{%- if content['type'] == 'image_url' %}
|
| 79 |
+
{%- set media_url = content.get('image_url', {}).get('url', '') %}
|
| 80 |
+
{%- set url_lower = media_url.lower() %}
|
| 81 |
+
{%- set image_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".svg"] %}
|
| 82 |
+
{%- set video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"] %}
|
| 83 |
+
{%- set ns_check = namespace(is_video=False) %}
|
| 84 |
+
{%- for ext in video_extensions %}
|
| 85 |
+
{%- if url_lower.endswith(ext) %}
|
| 86 |
+
{%- set ns_check.is_video = True %}
|
| 87 |
+
{%- endif %}
|
| 88 |
+
{%- endfor %}
|
| 89 |
+
{%- if ns_check.is_video %}
|
| 90 |
+
{%- set video_id = 'video_%02d' % ns_vid.count %}
|
| 91 |
+
{%- set ns_vid.count = ns_vid.count + 1 %}
|
| 92 |
+
{{- '<|mime_start|>{"id": "' + video_id + '", "type": "video/mp4", "filename": "video.mp4"}<|mime_end|>\n' }}
|
| 93 |
+
{{- '<|video_aux_start|>다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. {"video_duration": "<|video_meta_duration|>"}<|video_aux_end|>\n' }}
|
| 94 |
+
{{- '<|video_start|><|VIDEO_PAD|><|video_end|>' }}
|
| 95 |
+
{%- else %}
|
| 96 |
+
{%- set image_id = 'image_%02d' % ns_img.count %}
|
| 97 |
+
{%- set ns_img.count = ns_img.count + 1 %}
|
| 98 |
+
{{- '<|mime_start|>{"id": "' + image_id + '", "type": "image/jpeg", "filename": "image.jpg"}<|mime_end|>\n' }}
|
| 99 |
+
{{- '<|discrete_image_start|><|DISCRETE_IMAGE_PAD|><|discrete_image_end|>\n' }}
|
| 100 |
+
{{- '<|image_start|><|IMAGE_PAD|><|image_end|>' }}
|
| 101 |
+
{%- endif %}
|
| 102 |
+
{%- elif content['type'] == 'input_audio' %}
|
| 103 |
+
{%- set audio_id = 'audio_%02d' % ns_aud.count %}
|
| 104 |
+
{%- set ns_aud.count = ns_aud.count + 1 %}
|
| 105 |
+
{%- set input_audio = content.get('input_audio', {}) %}
|
| 106 |
+
{{- '<|mime_start|>{"id": "' + audio_id + '", "type": "audio/mpeg", "filename": "user_query.wav"}<|mime_end|>\n' }}
|
| 107 |
+
{{- '<|audio_aux_start|>다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. {"audio_duration": "<|audio_meta_duration|>"}<|audio_aux_end|>\n'}}
|
| 108 |
+
{{- '<|discrete_audio_start|><|DISCRETE_AUDIO_PAD|><|discrete_audio_end|>\n'}}
|
| 109 |
+
{{- '<|audio_start|><|AUDIO_PAD|><|audio_end|>'}}
|
| 110 |
+
{%- elif content['type'] == 'text' %}
|
| 111 |
+
{{- content['text'] }}
|
| 112 |
+
{%- endif %}
|
| 113 |
+
{%- endfor %}
|
| 114 |
+
{{- '<|im_end|>\n'}}
|
| 115 |
+
{%- endif %}
|
| 116 |
+
{%- elif message.role == 'assistant' %}
|
| 117 |
+
{%- set reasoning_content = '' %}
|
| 118 |
+
{%- if message.get('reasoning_content') is string %}
|
| 119 |
+
{%- set reasoning_content = message.get('reasoning_content') %}
|
| 120 |
+
{%- else %}
|
| 121 |
+
{%- if '</think>' in content %}
|
| 122 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 123 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 124 |
+
{%- endif %}
|
| 125 |
+
{%- endif %}
|
| 126 |
+
{%- if loop.index0 > ns.last_user_index %}
|
| 127 |
+
{%- if loop.last or reasoning_content %}
|
| 128 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' }}
|
| 129 |
+
{%- else %}
|
| 130 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 131 |
+
{%- endif %}
|
| 132 |
+
{%- else %}
|
| 133 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 134 |
+
{%- endif %}
|
| 135 |
+
{{- content }}
|
| 136 |
+
{%- if message.get('tool_calls') %}
|
| 137 |
+
{%- for tool_call in message.get('tool_calls', []) %}
|
| 138 |
+
{%- if not loop.first or content %}
|
| 139 |
+
{{- '\n' }}
|
| 140 |
+
{%- endif %}
|
| 141 |
+
{%- if tool_call.function %}
|
| 142 |
+
{%- set tool_call = tool_call.function %}
|
| 143 |
+
{%- endif %}
|
| 144 |
+
{{- '<tool_call>' + tool_call.name + '\n' }}
|
| 145 |
+
{%- set _args = tool_call.arguments %}
|
| 146 |
+
{%- for k, v in _args.items() %}
|
| 147 |
+
{{- '<arg_key>' + k + '</arg_key>\n' }}
|
| 148 |
+
{{- '<arg_value>' + (v | tojson if v is not string else v) + '</arg_value>\n' }}
|
| 149 |
+
{%- endfor %}
|
| 150 |
+
{{- '</tool_call>' }}
|
| 151 |
+
{%- endfor %}
|
| 152 |
+
{%- endif %}
|
| 153 |
+
{{- '<|im_end|>\n' }}
|
| 154 |
+
{%- elif message.role == 'tool' %}
|
| 155 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
|
| 156 |
+
{{- '<|im_start|>tool' }}
|
| 157 |
+
{%- endif %}
|
| 158 |
+
{{- '\n<tool_response>' + message.get('name', '') + '\n' }}
|
| 159 |
+
{%- if message['content'] is string %}
|
| 160 |
+
{{- content }}
|
| 161 |
+
{%- endif %}
|
| 162 |
+
{{- '\n</tool_response>' }}
|
| 163 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
|
| 164 |
+
{{- '<|im_end|>\n' }}
|
| 165 |
+
{%- endif %}
|
| 166 |
+
{%- endif %}
|
| 167 |
+
{%- endfor %}
|
| 168 |
+
{%- if add_generation_prompt %}
|
| 169 |
+
{{- '<|im_start|>assistant\n<think>\n' }}
|
| 170 |
+
{%- if skip_reasoning is defined and skip_reasoning is true %}
|
| 171 |
+
{{- '\n</think>\n\n' }}
|
| 172 |
+
{%- endif %}
|
| 173 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"anyres": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"HCXVisionV2ForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"audio_config": {
|
| 7 |
+
"activation_dropout": 0.0,
|
| 8 |
+
"activation_function": "gelu",
|
| 9 |
+
"add_cross_attention": false,
|
| 10 |
+
"architectures": [
|
| 11 |
+
"Qwen2AudioEncoder"
|
| 12 |
+
],
|
| 13 |
+
"attention_dropout": 0.0,
|
| 14 |
+
"bad_words_ids": null,
|
| 15 |
+
"begin_suppress_tokens": null,
|
| 16 |
+
"bos_token_id": null,
|
| 17 |
+
"chunk_size_feed_forward": 0,
|
| 18 |
+
"cross_attention_hidden_size": null,
|
| 19 |
+
"d_model": 1280,
|
| 20 |
+
"decoder_start_token_id": null,
|
| 21 |
+
"diversity_penalty": 0.0,
|
| 22 |
+
"do_sample": false,
|
| 23 |
+
"dropout": 0.0,
|
| 24 |
+
"early_stopping": false,
|
| 25 |
+
"encoder_attention_heads": 20,
|
| 26 |
+
"encoder_ffn_dim": 5120,
|
| 27 |
+
"encoder_layerdrop": 0.0,
|
| 28 |
+
"encoder_layers": 32,
|
| 29 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 30 |
+
"eos_token_id": null,
|
| 31 |
+
"exponential_decay_length_penalty": null,
|
| 32 |
+
"finetuning_task": null,
|
| 33 |
+
"forced_bos_token_id": null,
|
| 34 |
+
"forced_eos_token_id": null,
|
| 35 |
+
"id2label": {
|
| 36 |
+
"0": "LABEL_0",
|
| 37 |
+
"1": "LABEL_1"
|
| 38 |
+
},
|
| 39 |
+
"init_std": 0.02,
|
| 40 |
+
"initializer_range": 0.02,
|
| 41 |
+
"is_decoder": false,
|
| 42 |
+
"is_encoder_decoder": false,
|
| 43 |
+
"label2id": {
|
| 44 |
+
"LABEL_0": 0,
|
| 45 |
+
"LABEL_1": 1
|
| 46 |
+
},
|
| 47 |
+
"length_penalty": 1.0,
|
| 48 |
+
"max_length": 20,
|
| 49 |
+
"max_source_positions": 1500,
|
| 50 |
+
"min_length": 0,
|
| 51 |
+
"model_type": "qwen2_audio_encoder",
|
| 52 |
+
"no_repeat_ngram_size": 0,
|
| 53 |
+
"num_beam_groups": 1,
|
| 54 |
+
"num_beams": 1,
|
| 55 |
+
"num_hidden_layers": 32,
|
| 56 |
+
"num_mel_bins": 128,
|
| 57 |
+
"num_return_sequences": 1,
|
| 58 |
+
"output_attentions": false,
|
| 59 |
+
"output_hidden_states": false,
|
| 60 |
+
"output_scores": false,
|
| 61 |
+
"pad_token_id": null,
|
| 62 |
+
"prefix": null,
|
| 63 |
+
"problem_type": null,
|
| 64 |
+
"pruned_heads": {},
|
| 65 |
+
"remove_invalid_values": false,
|
| 66 |
+
"repetition_penalty": 1.0,
|
| 67 |
+
"return_dict": true,
|
| 68 |
+
"return_dict_in_generate": false,
|
| 69 |
+
"scale_embedding": false,
|
| 70 |
+
"sep_token_id": null,
|
| 71 |
+
"suppress_tokens": null,
|
| 72 |
+
"task_specific_params": null,
|
| 73 |
+
"temperature": 1.0,
|
| 74 |
+
"tf_legacy_loss": false,
|
| 75 |
+
"tie_encoder_decoder": false,
|
| 76 |
+
"tie_word_embeddings": true,
|
| 77 |
+
"tokenizer_class": null,
|
| 78 |
+
"top_k": 50,
|
| 79 |
+
"top_p": 1.0,
|
| 80 |
+
"torch_dtype": "float32",
|
| 81 |
+
"torchscript": false,
|
| 82 |
+
"typical_p": 1.0,
|
| 83 |
+
"use_bfloat16": false
|
| 84 |
+
},
|
| 85 |
+
"audio_model_name_or_path": null,
|
| 86 |
+
"audio_projector_type": "mlp",
|
| 87 |
+
"audio_start_id": 128071,
|
| 88 |
+
"audio_token_id": 128071,
|
| 89 |
+
"auto_map": {
|
| 90 |
+
"AutoConfig": "configuration_vlm.HCXVisionConfig",
|
| 91 |
+
"AutoModelForCausalLM": "modeling_vlm.HCXVisionForCausalLM",
|
| 92 |
+
"AutoModelForSequenceClassification": "modeling_vlm.HCXVisionForSequenceClassification"
|
| 93 |
+
},
|
| 94 |
+
"discrete_audio_config": {
|
| 95 |
+
"model_name_or_path": null,
|
| 96 |
+
"model_type": "cosyvoice2",
|
| 97 |
+
"torch_dtype": "float32"
|
| 98 |
+
},
|
| 99 |
+
"discrete_audio_model_name_or_path": null,
|
| 100 |
+
"discrete_audio_start_id": 128074,
|
| 101 |
+
"discrete_audio_token_id": 128074,
|
| 102 |
+
"discrete_audio_unit_0_id": 128606,
|
| 103 |
+
"discrete_image_start_id": 128069,
|
| 104 |
+
"discrete_image_token_id": 128069,
|
| 105 |
+
"discrete_image_unit_0_id": 135168,
|
| 106 |
+
"discrete_vision_config": {
|
| 107 |
+
"model_name_or_path": null,
|
| 108 |
+
"model_type": "ta_tok",
|
| 109 |
+
"torch_dtype": "float32"
|
| 110 |
+
},
|
| 111 |
+
"discrete_vision_model_name_or_path": null,
|
| 112 |
+
"end_token_id": 128001,
|
| 113 |
+
"eos_token_id": 128001,
|
| 114 |
+
"freeze_audio_projector": true,
|
| 115 |
+
"freeze_before_sampler": false,
|
| 116 |
+
"freeze_decoder": false,
|
| 117 |
+
"freeze_encoder": true,
|
| 118 |
+
"freeze_mm_projector": false,
|
| 119 |
+
"freeze_video_audio_compressor": false,
|
| 120 |
+
"hidden_size": 4096,
|
| 121 |
+
"ignore_index": -100,
|
| 122 |
+
"image_token_id": 128062,
|
| 123 |
+
"img_start_id": 128062,
|
| 124 |
+
"is_safetensor_save": true,
|
| 125 |
+
"max_num_grids": -1,
|
| 126 |
+
"mm_projector_type": "linear",
|
| 127 |
+
"model_type": "vlm",
|
| 128 |
+
"num_queries_vis_abstractor": -1,
|
| 129 |
+
"possible_resolutions": [],
|
| 130 |
+
"proj_pos_emb": true,
|
| 131 |
+
"proj_prenorm": false,
|
| 132 |
+
"q_former_model_name_or_path": null,
|
| 133 |
+
"text_config": {
|
| 134 |
+
"add_cross_attention": false,
|
| 135 |
+
"architectures": [
|
| 136 |
+
"LlamaForCausalLM"
|
| 137 |
+
],
|
| 138 |
+
"attention_bias": false,
|
| 139 |
+
"attention_dropout": 0.0,
|
| 140 |
+
"bad_words_ids": null,
|
| 141 |
+
"begin_suppress_tokens": null,
|
| 142 |
+
"bos_token_id": 128000,
|
| 143 |
+
"chunk_size_feed_forward": 0,
|
| 144 |
+
"cross_attention_hidden_size": null,
|
| 145 |
+
"decoder_start_token_id": null,
|
| 146 |
+
"diversity_penalty": 0.0,
|
| 147 |
+
"do_sample": false,
|
| 148 |
+
"early_stopping": false,
|
| 149 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 150 |
+
"eos_token_id": 128001,
|
| 151 |
+
"exponential_decay_length_penalty": null,
|
| 152 |
+
"finetuning_task": null,
|
| 153 |
+
"forced_bos_token_id": null,
|
| 154 |
+
"forced_eos_token_id": null,
|
| 155 |
+
"head_dim": 128,
|
| 156 |
+
"hidden_act": "silu",
|
| 157 |
+
"hidden_size": 4096,
|
| 158 |
+
"id2label": {
|
| 159 |
+
"0": "LABEL_0",
|
| 160 |
+
"1": "LABEL_1"
|
| 161 |
+
},
|
| 162 |
+
"initializer_range": 0.02,
|
| 163 |
+
"intermediate_size": 12288,
|
| 164 |
+
"is_decoder": false,
|
| 165 |
+
"is_encoder_decoder": false,
|
| 166 |
+
"label2id": {
|
| 167 |
+
"LABEL_0": 0,
|
| 168 |
+
"LABEL_1": 1
|
| 169 |
+
},
|
| 170 |
+
"length_penalty": 1.0,
|
| 171 |
+
"logits_scaling": 1.0,
|
| 172 |
+
"max_length": 20,
|
| 173 |
+
"max_position_embeddings": 8192,
|
| 174 |
+
"min_length": 0,
|
| 175 |
+
"mlp_bias": false,
|
| 176 |
+
"model_type": "llama",
|
| 177 |
+
"no_repeat_ngram_size": 0,
|
| 178 |
+
"num_attention_heads": 32,
|
| 179 |
+
"num_beam_groups": 1,
|
| 180 |
+
"num_beams": 1,
|
| 181 |
+
"num_hidden_layers": 36,
|
| 182 |
+
"num_key_value_heads": 8,
|
| 183 |
+
"num_return_sequences": 1,
|
| 184 |
+
"output_attentions": false,
|
| 185 |
+
"output_hidden_states": false,
|
| 186 |
+
"output_scores": false,
|
| 187 |
+
"pad_token_id": null,
|
| 188 |
+
"prefix": null,
|
| 189 |
+
"pretraining_tp": 1,
|
| 190 |
+
"problem_type": null,
|
| 191 |
+
"pruned_heads": {},
|
| 192 |
+
"remove_invalid_values": false,
|
| 193 |
+
"repetition_penalty": 1.0,
|
| 194 |
+
"return_dict": true,
|
| 195 |
+
"return_dict_in_generate": false,
|
| 196 |
+
"rms_norm_eps": 1e-06,
|
| 197 |
+
"rope_scaling": null,
|
| 198 |
+
"rope_theta": 5000000,
|
| 199 |
+
"sep_token_id": null,
|
| 200 |
+
"suppress_tokens": null,
|
| 201 |
+
"task_specific_params": null,
|
| 202 |
+
"temperature": 1.0,
|
| 203 |
+
"tf_legacy_loss": false,
|
| 204 |
+
"tie_encoder_decoder": false,
|
| 205 |
+
"tie_word_embeddings": false,
|
| 206 |
+
"tokenizer_class": null,
|
| 207 |
+
"top_k": 50,
|
| 208 |
+
"top_p": 1.0,
|
| 209 |
+
"torch_dtype": "float32",
|
| 210 |
+
"torchscript": false,
|
| 211 |
+
"typical_p": 1.0,
|
| 212 |
+
"use_bfloat16": false,
|
| 213 |
+
"use_cache": true,
|
| 214 |
+
"vocab_size": 200704
|
| 215 |
+
},
|
| 216 |
+
"text_model_name_or_path": null,
|
| 217 |
+
"torch_dtype": "float32",
|
| 218 |
+
"transformers_version": "4.52.4",
|
| 219 |
+
"unpad": false,
|
| 220 |
+
"use_1x1_grid": false,
|
| 221 |
+
"use_nth_layer": -2,
|
| 222 |
+
"video_audio_compressor_type": "mambamia",
|
| 223 |
+
"video_audio_start_id": 128070,
|
| 224 |
+
"video_audio_token_id": 128070,
|
| 225 |
+
"video_first_last_frames_slows": null,
|
| 226 |
+
"video_max_num_frames": 120,
|
| 227 |
+
"video_num_queries_fast": null,
|
| 228 |
+
"video_num_queries_slow": null,
|
| 229 |
+
"video_start_id": 128063,
|
| 230 |
+
"video_token_id": 128063,
|
| 231 |
+
"vision_config": {
|
| 232 |
+
"add_cross_attention": false,
|
| 233 |
+
"anyres": false,
|
| 234 |
+
"architectures": [
|
| 235 |
+
"Qwen2_5_VisionTransformerPretrainedModel"
|
| 236 |
+
],
|
| 237 |
+
"bad_words_ids": null,
|
| 238 |
+
"begin_suppress_tokens": null,
|
| 239 |
+
"bos_token_id": null,
|
| 240 |
+
"chunk_size_feed_forward": 0,
|
| 241 |
+
"cross_attention_hidden_size": null,
|
| 242 |
+
"decoder_start_token_id": null,
|
| 243 |
+
"depth": 32,
|
| 244 |
+
"diversity_penalty": 0.0,
|
| 245 |
+
"do_sample": false,
|
| 246 |
+
"early_stopping": false,
|
| 247 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 248 |
+
"eos_token_id": null,
|
| 249 |
+
"exponential_decay_length_penalty": null,
|
| 250 |
+
"finetuning_task": null,
|
| 251 |
+
"forced_bos_token_id": null,
|
| 252 |
+
"forced_eos_token_id": null,
|
| 253 |
+
"fullatt_block_indexes": [
|
| 254 |
+
7,
|
| 255 |
+
15,
|
| 256 |
+
23,
|
| 257 |
+
31
|
| 258 |
+
],
|
| 259 |
+
"hidden_act": "silu",
|
| 260 |
+
"hidden_size": 1280,
|
| 261 |
+
"id2label": {
|
| 262 |
+
"0": "LABEL_0",
|
| 263 |
+
"1": "LABEL_1"
|
| 264 |
+
},
|
| 265 |
+
"in_channels": 3,
|
| 266 |
+
"in_chans": 3,
|
| 267 |
+
"initializer_range": 0.02,
|
| 268 |
+
"intermediate_size": 3456,
|
| 269 |
+
"is_decoder": false,
|
| 270 |
+
"is_encoder_decoder": false,
|
| 271 |
+
"label2id": {
|
| 272 |
+
"LABEL_0": 0,
|
| 273 |
+
"LABEL_1": 1
|
| 274 |
+
},
|
| 275 |
+
"length_penalty": 1.0,
|
| 276 |
+
"max_length": 20,
|
| 277 |
+
"max_num_grids": -1,
|
| 278 |
+
"min_length": 0,
|
| 279 |
+
"model_type": "qwen2_5_vl",
|
| 280 |
+
"no_repeat_ngram_size": 0,
|
| 281 |
+
"num_beam_groups": 1,
|
| 282 |
+
"num_beams": 1,
|
| 283 |
+
"num_heads": 16,
|
| 284 |
+
"num_return_sequences": 1,
|
| 285 |
+
"out_hidden_size": 5120,
|
| 286 |
+
"output_attentions": false,
|
| 287 |
+
"output_hidden_states": false,
|
| 288 |
+
"output_scores": false,
|
| 289 |
+
"pad_token_id": null,
|
| 290 |
+
"patch_size": 14,
|
| 291 |
+
"prefix": null,
|
| 292 |
+
"problem_type": null,
|
| 293 |
+
"pruned_heads": {},
|
| 294 |
+
"remove_invalid_values": false,
|
| 295 |
+
"repetition_penalty": 1.0,
|
| 296 |
+
"return_dict": true,
|
| 297 |
+
"return_dict_in_generate": false,
|
| 298 |
+
"sep_token_id": null,
|
| 299 |
+
"spatial_merge_size": 2,
|
| 300 |
+
"spatial_patch_size": 14,
|
| 301 |
+
"suppress_tokens": null,
|
| 302 |
+
"task_specific_params": null,
|
| 303 |
+
"temperature": 1.0,
|
| 304 |
+
"temporal_patch_size": 2,
|
| 305 |
+
"tf_legacy_loss": false,
|
| 306 |
+
"tie_encoder_decoder": false,
|
| 307 |
+
"tie_word_embeddings": true,
|
| 308 |
+
"tokenizer_class": null,
|
| 309 |
+
"tokens_per_second": 2,
|
| 310 |
+
"top_k": 50,
|
| 311 |
+
"top_p": 1.0,
|
| 312 |
+
"torch_dtype": "float32",
|
| 313 |
+
"torchscript": false,
|
| 314 |
+
"typical_p": 1.0,
|
| 315 |
+
"use_bfloat16": false,
|
| 316 |
+
"window_size": 112
|
| 317 |
+
},
|
| 318 |
+
"vision_input_chunk_size": null,
|
| 319 |
+
"vision_model_name_or_path": null
|
| 320 |
+
}
|
configuration_hyperclovax.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""LLaMA model configuration"""
|
| 21 |
+
|
| 22 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 23 |
+
|
| 24 |
+
# from transformers.modeling_rope_utils import rope_config_validation
|
| 25 |
+
# from transformers import PretrainedConfig, rope_config_validation
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class HyperCLOVAXConfig(PretrainedConfig):
|
| 29 |
+
r"""
|
| 30 |
+
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
|
| 31 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 32 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
| 33 |
+
|
| 34 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 35 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 40 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
| 41 |
+
`inputs_ids` passed when calling [`LlamaModel`]
|
| 42 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 43 |
+
Dimension of the hidden representations.
|
| 44 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
| 45 |
+
Dimension of the MLP representations.
|
| 46 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 47 |
+
Number of hidden layers in the Transformer decoder.
|
| 48 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 49 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 50 |
+
num_key_value_heads (`int`, *optional*):
|
| 51 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 52 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 53 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 54 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 55 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 56 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 57 |
+
`num_attention_heads`.
|
| 58 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 59 |
+
The non-linear activation function (function or string) in the decoder.
|
| 60 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 61 |
+
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
|
| 62 |
+
Llama 2 up to 4096, CodeLlama up to 16384.
|
| 63 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 64 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 65 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 66 |
+
The epsilon used by the rms normalization layers.
|
| 67 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 68 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 69 |
+
relevant if `config.is_decoder=True`.
|
| 70 |
+
pad_token_id (`int`, *optional*):
|
| 71 |
+
Padding token id.
|
| 72 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 73 |
+
Beginning of stream token id.
|
| 74 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 75 |
+
End of stream token id.
|
| 76 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
| 77 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
| 78 |
+
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
| 79 |
+
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
| 80 |
+
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
| 81 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 82 |
+
Whether to tie weight embeddings
|
| 83 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 84 |
+
The base period of the RoPE embeddings.
|
| 85 |
+
rope_scaling (`Dict`, *optional*):
|
| 86 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 87 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 88 |
+
accordingly.
|
| 89 |
+
Expected contents:
|
| 90 |
+
`rope_type` (`str`):
|
| 91 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 92 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 93 |
+
`factor` (`float`, *optional*):
|
| 94 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 95 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 96 |
+
original maximum pre-trained length.
|
| 97 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 98 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 99 |
+
pretraining.
|
| 100 |
+
`attention_factor` (`float`, *optional*):
|
| 101 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 102 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 103 |
+
`factor` field to infer the suggested value.
|
| 104 |
+
`beta_fast` (`float`, *optional*):
|
| 105 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 106 |
+
ramp function. If unspecified, it defaults to 32.
|
| 107 |
+
`beta_slow` (`float`, *optional*):
|
| 108 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 109 |
+
ramp function. If unspecified, it defaults to 1.
|
| 110 |
+
`short_factor` (`List[float]`, *optional*):
|
| 111 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 112 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 113 |
+
size divided by the number of attention heads divided by 2
|
| 114 |
+
`long_factor` (`List[float]`, *optional*):
|
| 115 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 116 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 117 |
+
size divided by the number of attention heads divided by 2
|
| 118 |
+
`low_freq_factor` (`float`, *optional*):
|
| 119 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 120 |
+
`high_freq_factor` (`float`, *optional*):
|
| 121 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 122 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 123 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 124 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 125 |
+
The dropout ratio for the attention probabilities.
|
| 126 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 127 |
+
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
| 128 |
+
head_dim (`int`, *optional*):
|
| 129 |
+
The attention head dimension. If None, it will default to hidden_size // num_heads
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
>>> from transformers import LlamaModel, LlamaConfig
|
| 133 |
+
|
| 134 |
+
>>> # Initializing a LLaMA llama-7b style configuration
|
| 135 |
+
>>> configuration = LlamaConfig()
|
| 136 |
+
|
| 137 |
+
>>> # Initializing a model from the llama-7b style configuration
|
| 138 |
+
>>> model = LlamaModel(configuration)
|
| 139 |
+
|
| 140 |
+
>>> # Accessing the model configuration
|
| 141 |
+
>>> configuration = model.config
|
| 142 |
+
```"""
|
| 143 |
+
|
| 144 |
+
model_type = "hyperclovax"
|
| 145 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 146 |
+
|
| 147 |
+
def __init__(
|
| 148 |
+
self,
|
| 149 |
+
vocab_size=32000,
|
| 150 |
+
hidden_size=4096,
|
| 151 |
+
intermediate_size=11008,
|
| 152 |
+
num_hidden_layers=32,
|
| 153 |
+
num_attention_heads=32,
|
| 154 |
+
num_key_value_heads=None,
|
| 155 |
+
hidden_act="silu",
|
| 156 |
+
max_position_embeddings=2048,
|
| 157 |
+
initializer_range=0.02,
|
| 158 |
+
rms_norm_eps=1e-6,
|
| 159 |
+
use_cache=True,
|
| 160 |
+
pad_token_id=None,
|
| 161 |
+
bos_token_id=1,
|
| 162 |
+
eos_token_id=2,
|
| 163 |
+
pretraining_tp=1,
|
| 164 |
+
tie_word_embeddings=False,
|
| 165 |
+
rope_theta=10000.0,
|
| 166 |
+
rope_scaling=None,
|
| 167 |
+
attention_bias=False,
|
| 168 |
+
attention_dropout=0.0,
|
| 169 |
+
mlp_bias=False,
|
| 170 |
+
head_dim=None,
|
| 171 |
+
embedding_multiplier=1.0, # mup
|
| 172 |
+
logits_scaling=1.0, # mup
|
| 173 |
+
attention_multiplier=1.0, # mup
|
| 174 |
+
residual_multiplier=1.0, # mup
|
| 175 |
+
use_post_norm=False, # post-norm
|
| 176 |
+
auto_map={
|
| 177 |
+
"AutoConfig": "configuration_hyperclovax.HyperCLOVAXConfig",
|
| 178 |
+
"AutoModel": "modeling_hyperclovax.HyperCLOVAXModel",
|
| 179 |
+
"AutoModelForCausalLM": "modeling_hyperclovax.HyperCLOVAXForCausalLM",
|
| 180 |
+
},
|
| 181 |
+
**kwargs,
|
| 182 |
+
):
|
| 183 |
+
self.vocab_size = vocab_size
|
| 184 |
+
self.max_position_embeddings = max_position_embeddings
|
| 185 |
+
self.hidden_size = hidden_size
|
| 186 |
+
self.intermediate_size = intermediate_size
|
| 187 |
+
self.num_hidden_layers = num_hidden_layers
|
| 188 |
+
self.num_attention_heads = num_attention_heads
|
| 189 |
+
|
| 190 |
+
# for backward compatibility
|
| 191 |
+
if num_key_value_heads is None:
|
| 192 |
+
num_key_value_heads = num_attention_heads
|
| 193 |
+
|
| 194 |
+
self.num_key_value_heads = num_key_value_heads
|
| 195 |
+
self.hidden_act = hidden_act
|
| 196 |
+
self.initializer_range = initializer_range
|
| 197 |
+
self.rms_norm_eps = rms_norm_eps
|
| 198 |
+
self.pretraining_tp = pretraining_tp
|
| 199 |
+
self.use_cache = use_cache
|
| 200 |
+
self.rope_theta = rope_theta
|
| 201 |
+
self.rope_scaling = rope_scaling
|
| 202 |
+
self.attention_bias = attention_bias
|
| 203 |
+
self.attention_dropout = attention_dropout
|
| 204 |
+
self.mlp_bias = mlp_bias
|
| 205 |
+
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
| 206 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 207 |
+
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
| 208 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 209 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 210 |
+
# rope_config_validation(self)
|
| 211 |
+
|
| 212 |
+
# mup
|
| 213 |
+
self.embedding_multiplier = embedding_multiplier
|
| 214 |
+
self.logits_scaling = logits_scaling
|
| 215 |
+
self.attention_multiplier = attention_multiplier
|
| 216 |
+
self.residual_multiplier = residual_multiplier
|
| 217 |
+
|
| 218 |
+
# post-norm (dual-norm)
|
| 219 |
+
self.use_post_norm = use_post_norm
|
| 220 |
+
|
| 221 |
+
super().__init__(
|
| 222 |
+
pad_token_id=pad_token_id,
|
| 223 |
+
bos_token_id=bos_token_id,
|
| 224 |
+
eos_token_id=eos_token_id,
|
| 225 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 226 |
+
auto_map=auto_map,
|
| 227 |
+
**kwargs,
|
| 228 |
+
)
|
configuration_vlm.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import transformers
|
| 2 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class HCXVisionConfig(PretrainedConfig):
|
| 6 |
+
model_type = "vlm"
|
| 7 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 8 |
+
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
text_config=None,
|
| 12 |
+
vision_config=None,
|
| 13 |
+
discrete_vision_config=None,
|
| 14 |
+
audio_config=None,
|
| 15 |
+
discrete_audio_config=None,
|
| 16 |
+
text_model_name_or_path=None,
|
| 17 |
+
vision_model_name_or_path=None,
|
| 18 |
+
discrete_vision_model_name_or_path=None,
|
| 19 |
+
audio_model_name_or_path=None,
|
| 20 |
+
discrete_audio_model_name_or_path=None,
|
| 21 |
+
q_former_model_name_or_path=None,
|
| 22 |
+
mm_projector_type="mlp",
|
| 23 |
+
audio_projector_type="mlp",
|
| 24 |
+
video_audio_compressor_type=None,
|
| 25 |
+
use_nth_layer=-2,
|
| 26 |
+
img_start_id=128062, # <|IMAGE_PAD|> # Manually adjusted value from previous checkpoint
|
| 27 |
+
discrete_image_start_id=128250, # <|DISCRETE_AUDIO_PAD|>
|
| 28 |
+
discrete_image_unit_0_id=135166, # <|vision00000|>
|
| 29 |
+
video_start_id=128063, # <|VIDEO_PAD|>
|
| 30 |
+
video_audio_start_id=None, # <|VIDEO_AUDIO_PAD|> - will be set dynamically
|
| 31 |
+
audio_start_id=128253, # <|AUDIO_PAD|>
|
| 32 |
+
discrete_audio_start_id=128250, # <|DISCRETE_AUDIO_PAD|>
|
| 33 |
+
discrete_audio_unit_0_id=128604, # <|audio0000|>
|
| 34 |
+
freeze_encoder=False,
|
| 35 |
+
freeze_decoder=False,
|
| 36 |
+
freeze_mm_projector=False,
|
| 37 |
+
freeze_audio_projector=False,
|
| 38 |
+
freeze_video_audio_compressor=False,
|
| 39 |
+
anyres=False,
|
| 40 |
+
unpad=False,
|
| 41 |
+
max_num_grids=-1,
|
| 42 |
+
num_queries_vis_abstractor=-1,
|
| 43 |
+
video_num_queries_fast=None,
|
| 44 |
+
video_num_queries_slow=None,
|
| 45 |
+
video_first_last_frames_slows=None,
|
| 46 |
+
video_max_num_frames=None,
|
| 47 |
+
ignore_index=-100,
|
| 48 |
+
proj_pos_emb=True,
|
| 49 |
+
proj_prenorm=False,
|
| 50 |
+
use_1x1_grid=False,
|
| 51 |
+
possible_resolutions=[],
|
| 52 |
+
**kwargs,
|
| 53 |
+
):
|
| 54 |
+
from transformers import CONFIG_MAPPING
|
| 55 |
+
|
| 56 |
+
if kwargs.get("language_config", None) is not None: # for bc
|
| 57 |
+
text_config = CONFIG_MAPPING[kwargs["language_config"]["model_type"]](**kwargs["language_config"])
|
| 58 |
+
elif text_config is None and text_model_name_or_path is not None:
|
| 59 |
+
text_config = AutoConfig.from_pretrained(text_model_name_or_path, trust_remote_code=True)
|
| 60 |
+
if vision_config is None and vision_model_name_or_path is not None:
|
| 61 |
+
vision_config = AutoConfig.from_pretrained(vision_model_name_or_path, trust_remote_code=True)
|
| 62 |
+
if discrete_vision_config is None and discrete_vision_model_name_or_path is not None:
|
| 63 |
+
discrete_vision_config = {
|
| 64 |
+
"model_type": "ta_tok",
|
| 65 |
+
"model_name_or_path": discrete_vision_model_name_or_path,
|
| 66 |
+
}
|
| 67 |
+
if audio_config is None and audio_model_name_or_path is not None:
|
| 68 |
+
audio_config = AutoConfig.from_pretrained(audio_model_name_or_path)
|
| 69 |
+
if discrete_audio_config is None and discrete_audio_model_name_or_path is not None:
|
| 70 |
+
discrete_audio_config = {
|
| 71 |
+
"model_type": "cosyvoice2",
|
| 72 |
+
"model_name_or_path": discrete_audio_model_name_or_path,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
if isinstance(text_config, dict):
|
| 76 |
+
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
| 77 |
+
|
| 78 |
+
if isinstance(vision_config, dict):
|
| 79 |
+
if vision_config["model_type"] == "qwen2_5_vl":
|
| 80 |
+
vision_config["model_type"] = "qwen2_5_vl_visual"
|
| 81 |
+
assert transformers.__version__ >= "4.52.4", "please upgrade transformers to 4.52.4 or higher"
|
| 82 |
+
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
| 83 |
+
|
| 84 |
+
if isinstance(audio_config, dict):
|
| 85 |
+
audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
|
| 86 |
+
|
| 87 |
+
self.text_config = text_config
|
| 88 |
+
self.vision_config = vision_config
|
| 89 |
+
self.discrete_vision_config = discrete_vision_config
|
| 90 |
+
self.audio_config = audio_config
|
| 91 |
+
self.discrete_audio_config = discrete_audio_config
|
| 92 |
+
|
| 93 |
+
if text_config is not None:
|
| 94 |
+
# deepspeed zero3에서 config의 hidden_size를 보고 메모리 크기를 자동으로 결정함.
|
| 95 |
+
self.hidden_size = text_config.hidden_size if hasattr(text_config, "hidden_size") else text_config.n_embd
|
| 96 |
+
# add VLM configs
|
| 97 |
+
self.text_model_name_or_path = text_model_name_or_path
|
| 98 |
+
self.vision_model_name_or_path = vision_model_name_or_path
|
| 99 |
+
self.discrete_vision_model_name_or_path = discrete_vision_model_name_or_path
|
| 100 |
+
self.audio_model_name_or_path = audio_model_name_or_path
|
| 101 |
+
self.discrete_audio_model_name_or_path = discrete_audio_model_name_or_path
|
| 102 |
+
self.q_former_model_name_or_path = q_former_model_name_or_path
|
| 103 |
+
self.mm_projector_type = mm_projector_type
|
| 104 |
+
self.audio_projector_type = audio_projector_type
|
| 105 |
+
self.video_audio_compressor_type = video_audio_compressor_type
|
| 106 |
+
self.use_nth_layer = use_nth_layer
|
| 107 |
+
self.freeze_encoder = freeze_encoder
|
| 108 |
+
self.freeze_decoder = freeze_decoder
|
| 109 |
+
self.freeze_mm_projector = freeze_mm_projector
|
| 110 |
+
self.freeze_audio_projector = freeze_audio_projector
|
| 111 |
+
self.freeze_video_audio_compressor = freeze_video_audio_compressor
|
| 112 |
+
self.anyres = anyres
|
| 113 |
+
self.unpad = unpad
|
| 114 |
+
self.max_num_grids = max_num_grids
|
| 115 |
+
self.num_queries_vis_abstractor = num_queries_vis_abstractor
|
| 116 |
+
self.video_num_queries_fast = video_num_queries_fast
|
| 117 |
+
self.video_num_queries_slow = video_num_queries_slow
|
| 118 |
+
self.video_first_last_frames_slows = video_first_last_frames_slows
|
| 119 |
+
self.video_max_num_frames = video_max_num_frames
|
| 120 |
+
|
| 121 |
+
self.img_start_id = img_start_id
|
| 122 |
+
self.image_token_id = img_start_id
|
| 123 |
+
|
| 124 |
+
self.discrete_image_start_id = discrete_image_start_id
|
| 125 |
+
self.discrete_image_token_id = discrete_image_start_id
|
| 126 |
+
self.discrete_image_unit_0_id = discrete_image_unit_0_id
|
| 127 |
+
|
| 128 |
+
self.video_start_id = video_start_id
|
| 129 |
+
self.video_token_id = video_start_id
|
| 130 |
+
|
| 131 |
+
self.video_audio_start_id = video_audio_start_id
|
| 132 |
+
self.video_audio_token_id = video_audio_start_id
|
| 133 |
+
|
| 134 |
+
self.audio_start_id = audio_start_id
|
| 135 |
+
self.audio_token_id = audio_start_id
|
| 136 |
+
|
| 137 |
+
self.discrete_audio_start_id = discrete_audio_start_id
|
| 138 |
+
self.discrete_audio_token_id = discrete_audio_start_id
|
| 139 |
+
self.discrete_audio_unit_0_id = discrete_audio_unit_0_id
|
| 140 |
+
|
| 141 |
+
self.ignore_index = ignore_index
|
| 142 |
+
self.proj_pos_emb = proj_pos_emb
|
| 143 |
+
self.proj_prenorm = proj_prenorm
|
| 144 |
+
self.use_1x1_grid = use_1x1_grid
|
| 145 |
+
self.possible_resolutions = possible_resolutions
|
| 146 |
+
super().__init__(**kwargs)
|
| 147 |
+
if self.text_config is not None: # needed for HCXVisionForSequenceClassification
|
| 148 |
+
self.pad_token_id = self.text_config.pad_token_id
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
AutoConfig.register("vlm", HCXVisionConfig)
|
| 152 |
+
try:
|
| 153 |
+
from .configuration_hyperclovax import HyperCLOVAXConfig
|
| 154 |
+
|
| 155 |
+
AutoConfig.register("hyperclovax", HyperCLOVAXConfig)
|
| 156 |
+
except:
|
| 157 |
+
pass
|
| 158 |
+
try:
|
| 159 |
+
from transformers import CONFIG_MAPPING, MODEL_MAPPING
|
| 160 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 161 |
+
Qwen2_5_VisionTransformerPretrainedModel,
|
| 162 |
+
Qwen2_5_VLPatchMerger,
|
| 163 |
+
Qwen2_5_VLVisionConfig,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
MODEL_MAPPING.register(Qwen2_5_VLVisionConfig, Qwen2_5_VisionTransformerPretrainedModel)
|
| 167 |
+
CONFIG_MAPPING.register("qwen2_5_vl_visual", Qwen2_5_VLVisionConfig)
|
| 168 |
+
except:
|
| 169 |
+
pass
|
cosyvoice.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) (Mddct: Dinghao Zhou)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import librosa
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
DEFAULT_SAMPLE_RATE = 16000 # NOTE: 당분간 고정할 예정.
|
| 25 |
+
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600 # 0.1초, CosyVoice conv 두 번 지나도 code_len >= 1 보장
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class ModelConfig:
|
| 30 |
+
n_mels: int = 128
|
| 31 |
+
n_audio_ctx: int = 1500
|
| 32 |
+
n_audio_state: int = 1280
|
| 33 |
+
n_audio_head: int = 20
|
| 34 |
+
n_audio_layer: int = 6
|
| 35 |
+
n_codebook_size: int = 3**8
|
| 36 |
+
|
| 37 |
+
use_sdpa: bool = True
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scaling=None):
|
| 41 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 42 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
| 43 |
+
if scaling is not None:
|
| 44 |
+
t = t * scaling
|
| 45 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
| 46 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 47 |
+
|
| 48 |
+
return torch.cat((freqs_cis, freqs_cis), dim=-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def apply_rotary_emb(
|
| 52 |
+
xq: torch.Tensor,
|
| 53 |
+
xk: torch.Tensor,
|
| 54 |
+
freqs_cis: torch.Tensor,
|
| 55 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 56 |
+
real = torch.view_as_real(freqs_cis)
|
| 57 |
+
cos, sin = real[:, :, 0], real[:, :, 1]
|
| 58 |
+
cos = cos.unsqueeze(0).unsqueeze(2)
|
| 59 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
| 60 |
+
|
| 61 |
+
D = xq.shape[-1]
|
| 62 |
+
half_l, half_r = xq[:, :, :, : D // 2], xq[:, :, :, D // 2 :]
|
| 63 |
+
xq_r = torch.cat((-half_r, half_l), dim=-1)
|
| 64 |
+
|
| 65 |
+
D = xk.shape[-1]
|
| 66 |
+
|
| 67 |
+
half_l, half_r = xk[:, :, :, : D // 2], xk[:, :, :, D // 2 :]
|
| 68 |
+
xk_r = torch.cat((-half_r, half_l), dim=-1)
|
| 69 |
+
|
| 70 |
+
return xq * cos + xq_r * sin, xk * cos + xk_r * sin
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LayerNorm(nn.LayerNorm):
|
| 74 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 75 |
+
return super().forward(x.float()).type(x.dtype)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Linear(nn.Linear):
|
| 79 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 80 |
+
return F.linear(
|
| 81 |
+
x,
|
| 82 |
+
self.weight.to(x.dtype),
|
| 83 |
+
None if self.bias is None else self.bias.to(x.dtype),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Conv1d(nn.Conv1d):
|
| 88 |
+
def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
|
| 89 |
+
return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class MultiHeadAttention(nn.Module):
|
| 93 |
+
def __init__(self, n_state: int, n_head: int, use_sdpa: bool = True):
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.n_head = n_head
|
| 96 |
+
self.query = Linear(n_state, n_state)
|
| 97 |
+
self.key = Linear(n_state, n_state, bias=False)
|
| 98 |
+
self.value = Linear(n_state, n_state)
|
| 99 |
+
self.out = Linear(n_state, n_state)
|
| 100 |
+
|
| 101 |
+
self.use_sdpa = use_sdpa
|
| 102 |
+
|
| 103 |
+
def forward(
|
| 104 |
+
self,
|
| 105 |
+
x: torch.Tensor,
|
| 106 |
+
mask: Optional[torch.Tensor] = None,
|
| 107 |
+
):
|
| 108 |
+
q = self.query(x)
|
| 109 |
+
k = self.key(x)
|
| 110 |
+
v = self.value(x)
|
| 111 |
+
|
| 112 |
+
wv, qk = self.qkv_attention(q, k, v, mask)
|
| 113 |
+
return self.out(wv), qk
|
| 114 |
+
|
| 115 |
+
def qkv_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None):
|
| 116 |
+
_, _, D = q.shape
|
| 117 |
+
scale = (D // self.n_head) ** -0.25
|
| 118 |
+
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
| 119 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
| 120 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
| 121 |
+
|
| 122 |
+
if not self.use_sdpa:
|
| 123 |
+
k = k.permute(0, 2, 3, 1) * scale
|
| 124 |
+
qk = q @ k # (B, n_head, T, T)
|
| 125 |
+
if mask is not None:
|
| 126 |
+
qk = qk + mask
|
| 127 |
+
qk = qk.float()
|
| 128 |
+
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
|
| 129 |
+
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
| 130 |
+
else:
|
| 131 |
+
k = k.permute(0, 2, 1, 3) * scale
|
| 132 |
+
assert mask is not None
|
| 133 |
+
output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0)
|
| 134 |
+
output = output.transpose(1, 2).contiguous().view(q.size(0), -1, D) # (batch, time1, d_model)
|
| 135 |
+
return output, None
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class FSQCodebook(torch.nn.Module):
|
| 139 |
+
def __init__(self, dim: int, level: int = 3):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.project_down = torch.nn.Linear(dim, 8)
|
| 142 |
+
self.level = level
|
| 143 |
+
self.embed = None
|
| 144 |
+
|
| 145 |
+
@torch.inference_mode()
|
| 146 |
+
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
| 147 |
+
x = rearrange(x, "... d -> (...) d")
|
| 148 |
+
return x
|
| 149 |
+
|
| 150 |
+
@torch.inference_mode()
|
| 151 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 152 |
+
x_shape = x.shape
|
| 153 |
+
# pre-process
|
| 154 |
+
x = self.preprocess(x)
|
| 155 |
+
# quantize
|
| 156 |
+
h = self.project_down(x).float()
|
| 157 |
+
h = h.tanh()
|
| 158 |
+
h = h * 0.9990000128746033
|
| 159 |
+
h = h.round() + 1
|
| 160 |
+
# h = ((self.level - 1) * h).round() # range [-k, k]
|
| 161 |
+
powers = torch.pow(self.level, torch.arange(2**self.level, device=x.device, dtype=h.dtype))
|
| 162 |
+
mu = torch.sum(h * powers.unsqueeze(0), dim=-1)
|
| 163 |
+
ind = mu.reshape(x_shape[0], x_shape[1]).int()
|
| 164 |
+
return ind
|
| 165 |
+
|
| 166 |
+
@torch.inference_mode()
|
| 167 |
+
def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
|
| 168 |
+
raise NotImplementedError("There is no official up project component provided")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class FSQVectorQuantization(torch.nn.Module):
|
| 172 |
+
"""Vector quantization implementation (inference-only).
|
| 173 |
+
Args:
|
| 174 |
+
dim (int): Dimension
|
| 175 |
+
codebook_size (int): Codebook size
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
dim: int,
|
| 181 |
+
codebook_size: int,
|
| 182 |
+
):
|
| 183 |
+
super().__init__()
|
| 184 |
+
assert 3**8 == codebook_size
|
| 185 |
+
self._codebook = FSQCodebook(dim=dim, level=3)
|
| 186 |
+
self.codebook_size = codebook_size
|
| 187 |
+
|
| 188 |
+
@property
|
| 189 |
+
def codebook(self):
|
| 190 |
+
return self._codebook.embed
|
| 191 |
+
|
| 192 |
+
@torch.inference_mode()
|
| 193 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 194 |
+
return self._codebook.encode(x)
|
| 195 |
+
|
| 196 |
+
@torch.inference_mode()
|
| 197 |
+
def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
|
| 198 |
+
quantize = self._codebook.decode(embed_ind)
|
| 199 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
| 200 |
+
return quantize
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class FSMNMultiHeadAttention(MultiHeadAttention):
|
| 204 |
+
def __init__(
|
| 205 |
+
self,
|
| 206 |
+
n_state: int,
|
| 207 |
+
n_head: int,
|
| 208 |
+
kernel_size: int = 31,
|
| 209 |
+
use_sdpa: bool = True,
|
| 210 |
+
):
|
| 211 |
+
super().__init__(n_state, n_head)
|
| 212 |
+
|
| 213 |
+
self.fsmn_block = torch.nn.Conv1d(
|
| 214 |
+
n_state, n_state, kernel_size, stride=1, padding=0, groups=n_state, bias=False
|
| 215 |
+
)
|
| 216 |
+
self.left_padding = (kernel_size - 1) // 2
|
| 217 |
+
self.right_padding = kernel_size - 1 - self.left_padding
|
| 218 |
+
self.pad_fn = torch.nn.ConstantPad1d((self.left_padding, self.right_padding), 0.0)
|
| 219 |
+
|
| 220 |
+
self.use_sdpa = use_sdpa
|
| 221 |
+
|
| 222 |
+
def forward_fsmn(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None):
|
| 223 |
+
b, t, _, _ = inputs.size()
|
| 224 |
+
inputs = inputs.view(b, t, -1)
|
| 225 |
+
if mask is not None and mask.size(2) > 0: # time2 > 0
|
| 226 |
+
inputs = inputs * mask
|
| 227 |
+
x = inputs.transpose(1, 2)
|
| 228 |
+
x = self.pad_fn(x)
|
| 229 |
+
x = self.fsmn_block(x)
|
| 230 |
+
x = x.transpose(1, 2)
|
| 231 |
+
x += inputs
|
| 232 |
+
return x * mask
|
| 233 |
+
|
| 234 |
+
def qkv_attention(
|
| 235 |
+
self,
|
| 236 |
+
q: torch.Tensor,
|
| 237 |
+
k: torch.Tensor,
|
| 238 |
+
v: torch.Tensor,
|
| 239 |
+
mask: Optional[torch.Tensor] = None,
|
| 240 |
+
mask_pad: Optional[torch.Tensor] = None,
|
| 241 |
+
freqs_cis: Optional[torch.Tensor] = None,
|
| 242 |
+
):
|
| 243 |
+
_, _, D = q.shape
|
| 244 |
+
scale = (D // self.n_head) ** -0.25
|
| 245 |
+
q = q.view(*q.shape[:2], self.n_head, -1)
|
| 246 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
| 247 |
+
v = v.view(*v.shape[:2], self.n_head, -1)
|
| 248 |
+
|
| 249 |
+
if freqs_cis is not None:
|
| 250 |
+
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
|
| 251 |
+
|
| 252 |
+
fsm_memory = self.forward_fsmn(v, mask_pad)
|
| 253 |
+
|
| 254 |
+
q = q.permute(0, 2, 1, 3) * scale
|
| 255 |
+
v = v.permute(0, 2, 1, 3)
|
| 256 |
+
|
| 257 |
+
if not self.use_sdpa:
|
| 258 |
+
k = k.permute(0, 2, 3, 1) * scale
|
| 259 |
+
qk = q @ k # (B, n_head, T, T)
|
| 260 |
+
if mask is not None:
|
| 261 |
+
qk = qk + mask
|
| 262 |
+
qk = qk.float()
|
| 263 |
+
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
|
| 264 |
+
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
|
| 265 |
+
else:
|
| 266 |
+
k = k.permute(0, 2, 1, 3) * scale
|
| 267 |
+
assert mask is not None
|
| 268 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
| 269 |
+
q,
|
| 270 |
+
k,
|
| 271 |
+
v,
|
| 272 |
+
attn_mask=mask,
|
| 273 |
+
dropout_p=0.0,
|
| 274 |
+
scale=1.0,
|
| 275 |
+
)
|
| 276 |
+
output = output.transpose(1, 2).contiguous().view(q.size(0), -1, D) # (batch, time1, d_model)
|
| 277 |
+
return output, None, fsm_memory
|
| 278 |
+
|
| 279 |
+
def forward(
|
| 280 |
+
self,
|
| 281 |
+
x: torch.Tensor,
|
| 282 |
+
mask: Optional[torch.Tensor] = None,
|
| 283 |
+
mask_pad: Optional[torch.Tensor] = None,
|
| 284 |
+
freqs_cis: Optional[torch.Tensor] = None,
|
| 285 |
+
):
|
| 286 |
+
q = self.query(x)
|
| 287 |
+
k = self.key(x)
|
| 288 |
+
v = self.value(x)
|
| 289 |
+
|
| 290 |
+
wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad, freqs_cis)
|
| 291 |
+
return self.out(wv) + fsm_memory, qk
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class ResidualAttentionBlock(torch.nn.Module):
|
| 295 |
+
def __init__(
|
| 296 |
+
self,
|
| 297 |
+
n_state: int,
|
| 298 |
+
n_head: int,
|
| 299 |
+
kernel_size: int = 31,
|
| 300 |
+
use_sdpa: bool = False,
|
| 301 |
+
):
|
| 302 |
+
super().__init__()
|
| 303 |
+
|
| 304 |
+
self.attn = FSMNMultiHeadAttention(n_state, n_head, kernel_size, use_sdpa=use_sdpa)
|
| 305 |
+
self.attn_ln = LayerNorm(n_state, eps=1e-6)
|
| 306 |
+
|
| 307 |
+
n_mlp = n_state * 4
|
| 308 |
+
|
| 309 |
+
self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(), Linear(n_mlp, n_state))
|
| 310 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 311 |
+
|
| 312 |
+
def forward(
|
| 313 |
+
self,
|
| 314 |
+
x: torch.Tensor,
|
| 315 |
+
mask: Optional[torch.Tensor] = None,
|
| 316 |
+
mask_pad: Optional[torch.Tensor] = None,
|
| 317 |
+
freqs_cis: Optional[torch.Tensor] = None,
|
| 318 |
+
):
|
| 319 |
+
x = x + self.attn(self.attn_ln(x), mask=mask, mask_pad=mask_pad, freqs_cis=freqs_cis)[0]
|
| 320 |
+
|
| 321 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 322 |
+
return x
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class AudioEncoderV2(torch.nn.Module):
|
| 326 |
+
def __init__(
|
| 327 |
+
self,
|
| 328 |
+
n_mels: int,
|
| 329 |
+
n_state: int,
|
| 330 |
+
n_head: int,
|
| 331 |
+
n_layer: int,
|
| 332 |
+
stride: int,
|
| 333 |
+
use_sdpa: bool,
|
| 334 |
+
):
|
| 335 |
+
super().__init__()
|
| 336 |
+
self.stride = stride
|
| 337 |
+
|
| 338 |
+
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=stride, padding=1)
|
| 339 |
+
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
| 340 |
+
self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
|
| 341 |
+
self.blocks = torch.nn.ModuleList(
|
| 342 |
+
[ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa) for _ in range(n_layer)]
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
def forward(self, x: torch.Tensor, x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 346 |
+
"""
|
| 347 |
+
x : torch.Tensor, shape = (batch_size, n_mels, T)
|
| 348 |
+
the mel spectrogram of the audio
|
| 349 |
+
x_len: torch.Tensor, shape = (batch_size,)
|
| 350 |
+
length of each audio in x
|
| 351 |
+
"""
|
| 352 |
+
mask = self.make_non_pad_mask(x_len).unsqueeze(1)
|
| 353 |
+
x = torch.nn.functional.gelu(self.conv1(x * mask))
|
| 354 |
+
x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
|
| 355 |
+
mask = self.make_non_pad_mask(x_len).unsqueeze(1)
|
| 356 |
+
x = torch.nn.functional.gelu(self.conv2(x * mask))
|
| 357 |
+
x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
|
| 358 |
+
mask = self.make_non_pad_mask(x_len).unsqueeze(1)
|
| 359 |
+
x = x.permute(0, 2, 1) # (B, T // 2, n_state)
|
| 360 |
+
freqs_cis = self.freqs_cis.to(x.device)
|
| 361 |
+
mask_pad = mask.transpose(1, 2)
|
| 362 |
+
mask = self.mask_to_bias(mask, x.dtype)
|
| 363 |
+
|
| 364 |
+
tmp = torch.view_as_real(freqs_cis)
|
| 365 |
+
cos, sin = tmp[:, :, 0], tmp[:, :, 1]
|
| 366 |
+
|
| 367 |
+
cos = torch.cat((cos, cos), dim=-1)
|
| 368 |
+
sin = torch.cat((sin, sin), dim=-1)
|
| 369 |
+
cos = cos.unsqueeze(0).unsqueeze(2)
|
| 370 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
| 371 |
+
|
| 372 |
+
for block in self.blocks:
|
| 373 |
+
x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[: x.size(1)])
|
| 374 |
+
|
| 375 |
+
return x, x_len
|
| 376 |
+
|
| 377 |
+
@staticmethod
|
| 378 |
+
def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
| 379 |
+
"""Make mask tensor containing indices of non-padded part.
|
| 380 |
+
The sequences in a batch may have different lengths. To enable
|
| 381 |
+
batch computing, padding is need to make all sequence in same
|
| 382 |
+
size. To avoid the padding part pass value to context dependent
|
| 383 |
+
block such as attention or convolution , this padding part is
|
| 384 |
+
masked.
|
| 385 |
+
1 for non-padded part and 0 for padded part.
|
| 386 |
+
Parameters
|
| 387 |
+
----------
|
| 388 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
| 389 |
+
Returns:
|
| 390 |
+
-------
|
| 391 |
+
torch.Tensor: Mask tensor containing indices of padded part (B, max_T).
|
| 392 |
+
Examples:
|
| 393 |
+
>>> import torch
|
| 394 |
+
>>> import s3tokenizer
|
| 395 |
+
>>> lengths = torch.tensor([5, 3, 2])
|
| 396 |
+
>>> masks = s3tokenizer.make_non_pad_mask(lengths)
|
| 397 |
+
masks = [[1, 1, 1, 1, 1],
|
| 398 |
+
[1, 1, 1, 0, 0],
|
| 399 |
+
[1, 1, 0, 0, 0]]
|
| 400 |
+
"""
|
| 401 |
+
batch_size = lengths.size(0)
|
| 402 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
| 403 |
+
seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
|
| 404 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
| 405 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
| 406 |
+
mask = seq_range_expand >= seq_length_expand
|
| 407 |
+
return ~mask
|
| 408 |
+
|
| 409 |
+
@staticmethod
|
| 410 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 411 |
+
"""Convert bool-tensor to float-tensor for flash attention.
|
| 412 |
+
Parameters
|
| 413 |
+
----------
|
| 414 |
+
lengths (torch.Tensor): Batch of lengths (B, ?).
|
| 415 |
+
Returns:
|
| 416 |
+
-------
|
| 417 |
+
torch.Tensor: Mask tensor containing indices of padded part (B, ?).
|
| 418 |
+
Examples:
|
| 419 |
+
>>> import torch
|
| 420 |
+
>>> import s3tokenizer
|
| 421 |
+
>>> lengths = torch.tensor([5, 3, 2])
|
| 422 |
+
>>> masks = self.make_non_pad_mask(lengths)
|
| 423 |
+
masks = [[1, 1, 1, 1, 1],
|
| 424 |
+
[1, 1, 1, 0, 0],
|
| 425 |
+
[1, 1, 0, 0, 0]]
|
| 426 |
+
>>> new_masks = self.mask_to_bias(masks, torch.float32)
|
| 427 |
+
new_masks =
|
| 428 |
+
[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
|
| 429 |
+
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
|
| 430 |
+
[-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
|
| 431 |
+
"""
|
| 432 |
+
assert mask.dtype == torch.bool
|
| 433 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
| 434 |
+
mask = mask.to(dtype)
|
| 435 |
+
|
| 436 |
+
# attention mask bias
|
| 437 |
+
# NOTE(Mddct): torch.finfo jit issues
|
| 438 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
| 439 |
+
mask = (1.0 - mask) * -1.0e10
|
| 440 |
+
return mask
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class CosyvoiceEncoder(nn.Module):
|
| 444 |
+
"""S3 tokenizer of the CosyVoice2 implementation (inference-only).
|
| 445 |
+
Args:
|
| 446 |
+
config (ModelConfig): Config
|
| 447 |
+
"""
|
| 448 |
+
|
| 449 |
+
def __init__(self, config: ModelConfig = ModelConfig()):
|
| 450 |
+
super().__init__()
|
| 451 |
+
self.config = config
|
| 452 |
+
self.encoder = AudioEncoderV2(
|
| 453 |
+
self.config.n_mels,
|
| 454 |
+
self.config.n_audio_state,
|
| 455 |
+
self.config.n_audio_head,
|
| 456 |
+
self.config.n_audio_layer,
|
| 457 |
+
2,
|
| 458 |
+
self.config.use_sdpa,
|
| 459 |
+
)
|
| 460 |
+
self.quantizer = FSQVectorQuantization(
|
| 461 |
+
self.config.n_audio_state,
|
| 462 |
+
self.config.n_codebook_size,
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def forward(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 466 |
+
mel = self.mel_spectrogram(wav, n_mels=self.config.n_mels)
|
| 467 |
+
mel_len = torch.tensor([mel.shape[-1]]).to(self.device)
|
| 468 |
+
return self.quantize(mel, mel_len)
|
| 469 |
+
|
| 470 |
+
@torch.inference_mode()
|
| 471 |
+
def quantize(self, mel: torch.Tensor, mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 472 |
+
hidden, code_len = self.encoder(mel, mel_len)
|
| 473 |
+
code = self.quantizer.encode(hidden)
|
| 474 |
+
return code
|
| 475 |
+
|
| 476 |
+
@staticmethod
|
| 477 |
+
def mel_spectrogram(
|
| 478 |
+
wav: torch.Tensor,
|
| 479 |
+
n_mels: int = 80,
|
| 480 |
+
padding: int = 0,
|
| 481 |
+
) -> torch.Tensor:
|
| 482 |
+
"""
|
| 483 |
+
This method is based on the whisper.log_mel_spectrogram().
|
| 484 |
+
So, don't use this as a general mel spectrogram function.
|
| 485 |
+
"""
|
| 486 |
+
device = wav.device
|
| 487 |
+
if padding > 0:
|
| 488 |
+
wav = torch.nn.functional.pad(wav, (0, padding))
|
| 489 |
+
|
| 490 |
+
window = torch.hann_window(400).to(device)
|
| 491 |
+
stft = torch.stft(wav, 400, 160, window=window, return_complex=True)
|
| 492 |
+
mag = stft[..., :-1].abs() ** 2
|
| 493 |
+
|
| 494 |
+
filters = torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=n_mels)).to(device)
|
| 495 |
+
mel_spec = filters @ mag
|
| 496 |
+
|
| 497 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 498 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 499 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 500 |
+
return log_spec
|
| 501 |
+
|
| 502 |
+
@property
|
| 503 |
+
def device(self):
|
| 504 |
+
return next(self.parameters()).device
|
| 505 |
+
|
| 506 |
+
def freeze(self):
|
| 507 |
+
for p in self.parameters():
|
| 508 |
+
p.requires_grad = False
|
| 509 |
+
|
| 510 |
+
@classmethod
|
| 511 |
+
def from_pretrained(cls, model_path: str):
|
| 512 |
+
model = cls()
|
| 513 |
+
model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=True)
|
| 514 |
+
model.eval()
|
| 515 |
+
model.freeze()
|
| 516 |
+
return model
|
decoder/audio/NCCosybigvganDecoder.mar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b71bace1a8ed9f1eac40d98e99ebe9978a25b2de68c25d89674743d550d9abec
|
| 3 |
+
size 517187360
|
decoder/audio/NCZSCosybigvganDecoder.mar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d3cf46e024952093e19dded52befc61c751e3a759138cc055f8b008d1da34a0
|
| 3 |
+
size 539807544
|
decoder/vision/model_index.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "VisionTokenToImagePipeline",
|
| 3 |
+
"_diffusers_version": "0.32.2",
|
| 4 |
+
"_custom_pipeline": "pipeline",
|
| 5 |
+
"transformer": [
|
| 6 |
+
"pipeline",
|
| 7 |
+
"VisionTransformer"
|
| 8 |
+
],
|
| 9 |
+
"vae": [
|
| 10 |
+
"diffusers",
|
| 11 |
+
"AutoencoderKL"
|
| 12 |
+
],
|
| 13 |
+
"scheduler": [
|
| 14 |
+
"diffusers",
|
| 15 |
+
"FlowMatchEulerDiscreteScheduler"
|
| 16 |
+
],
|
| 17 |
+
"token_embedder": [
|
| 18 |
+
"pipeline",
|
| 19 |
+
"VisionTokenEmbedder"
|
| 20 |
+
],
|
| 21 |
+
"transformer2": [
|
| 22 |
+
"pipeline",
|
| 23 |
+
"VisionTransformer"
|
| 24 |
+
]
|
| 25 |
+
}
|
decoder/vision/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"base_image_seq_len": 256,
|
| 5 |
+
"base_shift": 0.5,
|
| 6 |
+
"invert_sigmas": false,
|
| 7 |
+
"max_image_seq_len": 4096,
|
| 8 |
+
"max_shift": 1.15,
|
| 9 |
+
"num_train_timesteps": 1000,
|
| 10 |
+
"shift": 1.0,
|
| 11 |
+
"shift_terminal": null,
|
| 12 |
+
"stochastic_sampling": false,
|
| 13 |
+
"time_shift_type": "exponential",
|
| 14 |
+
"use_beta_sigmas": false,
|
| 15 |
+
"use_dynamic_shifting": false,
|
| 16 |
+
"use_exponential_sigmas": false,
|
| 17 |
+
"use_karras_sigmas": false
|
| 18 |
+
}
|
decoder/vision/token_embedder/config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "VisionTokenEmbedder",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"embedding_dim": 1536,
|
| 5 |
+
"token_length": 729,
|
| 6 |
+
"vocab_size": 65536
|
| 7 |
+
}
|
decoder/vision/token_embedder/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d84815b2f681f45dbc5dbbad513b04f7fe3fa4444ca6410a9354555bb3410c7f
|
| 3 |
+
size 201329872
|
decoder/vision/transformer/config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "VisionTransformer",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"axes_dim": [
|
| 5 |
+
8,
|
| 6 |
+
36,
|
| 7 |
+
36
|
| 8 |
+
],
|
| 9 |
+
"context_in_dim": 1536,
|
| 10 |
+
"depth": 0,
|
| 11 |
+
"depth_single_blocks": 35,
|
| 12 |
+
"guidance_embed": false,
|
| 13 |
+
"hidden_size": 1920,
|
| 14 |
+
"in_channels": 16,
|
| 15 |
+
"mlp_ratio": 4.0,
|
| 16 |
+
"num_heads": 24,
|
| 17 |
+
"qkv_bias": true,
|
| 18 |
+
"theta": 10000,
|
| 19 |
+
"use_patchify": false,
|
| 20 |
+
"vec_in_dim": 1536
|
| 21 |
+
}
|
decoder/vision/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f07c014f8575090455e9d3c024b4f58a8ad480b957f2c45fc6eec4fc08edbe94
|
| 3 |
+
size 3914661840
|
decoder/vision/transformer2/config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "VisionTransformer",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"axes_dim": [
|
| 5 |
+
6,
|
| 6 |
+
18,
|
| 7 |
+
18
|
| 8 |
+
],
|
| 9 |
+
"context_in_dim": 1536,
|
| 10 |
+
"depth": 0,
|
| 11 |
+
"depth_single_blocks": 25,
|
| 12 |
+
"guidance_embed": false,
|
| 13 |
+
"hidden_size": 1008,
|
| 14 |
+
"in_channels": 16,
|
| 15 |
+
"mlp_ratio": 4.0,
|
| 16 |
+
"num_heads": 24,
|
| 17 |
+
"qkv_bias": true,
|
| 18 |
+
"theta": 10000,
|
| 19 |
+
"use_patchify": false,
|
| 20 |
+
"vec_in_dim": 1536
|
| 21 |
+
}
|
decoder/vision/transformer2/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca76187d0f52336f6d385c4d56f43f0e631a4030343b27647b30baf188bcbc96
|
| 3 |
+
size 777545632
|
decoder/vision/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.35.2",
|
| 4 |
+
"_name_or_path": "black-forest-labs/FLUX.1-schnell",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 16,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 1024,
|
| 28 |
+
"scaling_factor": 0.3611,
|
| 29 |
+
"shift_factor": 0.1159,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": false,
|
| 37 |
+
"use_quant_conv": false
|
| 38 |
+
}
|
decoder/vision/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5b59a26851551b67ae1fe58d32e76486e1e812def4696a4bea97f16604d40a3
|
| 3 |
+
size 167666902
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 128000,
|
| 4 |
+
"eos_token_id": 0,
|
| 5 |
+
"transformers_version": "4.52.4"
|
| 6 |
+
}
|
mambamia_videoaudio_compressor.py
ADDED
|
@@ -0,0 +1,803 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# This module is integrated into 'HyperCLOVAX-SEED-Omni-8B' to mitigate
|
| 3 |
+
# audio stream token explosion during hour-long video understanding.
|
| 4 |
+
# It utilizes the MambaMia architecture (AAAI-26 Oral) to
|
| 5 |
+
# effectively compress high-frequency audio tokens into a manageable
|
| 6 |
+
# context for the LLM.
|
| 7 |
+
# Research Context:
|
| 8 |
+
# - MambaMia: https://github.com/naver-ai/mambamia
|
| 9 |
+
# - LLaVA-AV-SSM: https://github.com/naver-ai/LLaVA-AV-SSM
|
| 10 |
+
# Acknowledgements:
|
| 11 |
+
# This implementation is heavily modified and extended from the following
|
| 12 |
+
# foundational repositories:
|
| 13 |
+
# - Transformers: https://github.com/huggingface/transformers (Apache License v2.0)
|
| 14 |
+
# - Mamba: https://github.com/state-spaces/mamba (Apache License v2.0)
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from torch import nn
|
| 23 |
+
from transformers.activations import ACT2FN
|
| 24 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 25 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 26 |
+
from transformers.utils import ModelOutput, logging
|
| 27 |
+
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ============================================================================
|
| 34 |
+
# Check for fast path availability
|
| 35 |
+
# ============================================================================
|
| 36 |
+
if is_mamba_2_ssm_available():
|
| 37 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 38 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
|
| 39 |
+
else:
|
| 40 |
+
selective_state_update = None
|
| 41 |
+
mamba_split_conv1d_scan_combined = None
|
| 42 |
+
|
| 43 |
+
if is_causal_conv1d_available():
|
| 44 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 45 |
+
else:
|
| 46 |
+
causal_conv1d_update, causal_conv1d_fn = None, None
|
| 47 |
+
|
| 48 |
+
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ============================================================================
|
| 52 |
+
# MambaMia2Config (Simplified for v04 only)
|
| 53 |
+
# ============================================================================
|
| 54 |
+
class MambaMia2Config(PretrainedConfig):
|
| 55 |
+
"""
|
| 56 |
+
Simplified MambaMia2 configuration for v04 version only.
|
| 57 |
+
"""
|
| 58 |
+
model_type = "mamba2"
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
num_heads=128,
|
| 63 |
+
head_dim=64,
|
| 64 |
+
vocab_size=32768,
|
| 65 |
+
hidden_size=4096,
|
| 66 |
+
state_size=128,
|
| 67 |
+
num_hidden_layers=64,
|
| 68 |
+
layer_norm_epsilon=1e-5,
|
| 69 |
+
pad_token_id=1,
|
| 70 |
+
bos_token_id=0,
|
| 71 |
+
eos_token_id=2,
|
| 72 |
+
expand=2,
|
| 73 |
+
conv_kernel=4,
|
| 74 |
+
n_groups=8,
|
| 75 |
+
use_bias=False,
|
| 76 |
+
use_conv_bias=True,
|
| 77 |
+
hidden_act="silu",
|
| 78 |
+
initializer_range=0.1,
|
| 79 |
+
residual_in_fp32=False,
|
| 80 |
+
time_step_rank="auto",
|
| 81 |
+
time_step_min=0.001,
|
| 82 |
+
time_step_max=0.1,
|
| 83 |
+
time_step_floor=1e-4,
|
| 84 |
+
time_step_limit=(0.0, float("inf")),
|
| 85 |
+
rescale_prenorm_residual=False,
|
| 86 |
+
use_cache=True,
|
| 87 |
+
norm_before_gate=True,
|
| 88 |
+
rms_norm=True,
|
| 89 |
+
chunk_size=256,
|
| 90 |
+
tie_word_embeddings=False,
|
| 91 |
+
mambamia_chunk_size=10,
|
| 92 |
+
**kwargs,
|
| 93 |
+
):
|
| 94 |
+
self.vocab_size = vocab_size
|
| 95 |
+
self.hidden_size = hidden_size
|
| 96 |
+
self.state_size = state_size
|
| 97 |
+
self.num_hidden_layers = num_hidden_layers
|
| 98 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
| 99 |
+
self.conv_kernel = conv_kernel
|
| 100 |
+
self.expand = expand
|
| 101 |
+
|
| 102 |
+
self.bos_token_id = bos_token_id
|
| 103 |
+
self.eos_token_id = eos_token_id
|
| 104 |
+
self.pad_token_id = pad_token_id
|
| 105 |
+
self.use_bias = use_bias
|
| 106 |
+
self.use_conv_bias = use_conv_bias
|
| 107 |
+
self.hidden_act = hidden_act
|
| 108 |
+
self.initializer_range = initializer_range
|
| 109 |
+
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
|
| 110 |
+
self.time_step_min = time_step_min
|
| 111 |
+
self.time_step_max = time_step_max
|
| 112 |
+
self.time_step_floor = time_step_floor
|
| 113 |
+
self.rescale_prenorm_residual = rescale_prenorm_residual
|
| 114 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 115 |
+
self.use_cache = use_cache
|
| 116 |
+
self.n_groups = n_groups
|
| 117 |
+
self.num_heads = num_heads
|
| 118 |
+
self.head_dim = head_dim
|
| 119 |
+
self.norm_before_gate = norm_before_gate
|
| 120 |
+
self.rms_norm = rms_norm
|
| 121 |
+
self.state_size = state_size
|
| 122 |
+
self.chunk_size = chunk_size
|
| 123 |
+
self.time_step_limit = time_step_limit
|
| 124 |
+
self.tie_word_embeddings = tie_word_embeddings
|
| 125 |
+
self.mambamia_chunk_size = mambamia_chunk_size
|
| 126 |
+
self.output_hidden_states = False
|
| 127 |
+
self.output_deltas = False
|
| 128 |
+
|
| 129 |
+
super().__init__(
|
| 130 |
+
bos_token_id=bos_token_id,
|
| 131 |
+
eos_token_id=eos_token_id,
|
| 132 |
+
pad_token_id=pad_token_id,
|
| 133 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 134 |
+
**kwargs,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ============================================================================
|
| 139 |
+
# Helper Modules
|
| 140 |
+
# ============================================================================
|
| 141 |
+
class MambaRMSNormGated(nn.Module):
|
| 142 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 145 |
+
self.variance_epsilon = eps
|
| 146 |
+
|
| 147 |
+
def forward(self, hidden_states, gate=None):
|
| 148 |
+
input_dtype = hidden_states.dtype
|
| 149 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 150 |
+
if gate is not None:
|
| 151 |
+
hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
|
| 152 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 153 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 154 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class MambaMia2RMSNorm(nn.Module):
|
| 158 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 161 |
+
self.variance_epsilon = eps
|
| 162 |
+
|
| 163 |
+
def forward(self, hidden_states):
|
| 164 |
+
input_dtype = hidden_states.dtype
|
| 165 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 166 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 167 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 168 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ============================================================================
|
| 172 |
+
# MambaMia2Mixer (v04 version - unidirectional with GPA)
|
| 173 |
+
# ============================================================================
|
| 174 |
+
class MambaMia2Mixer(nn.Module):
|
| 175 |
+
"""
|
| 176 |
+
Unidirectional Mamba2 Mixer for v04 version.
|
| 177 |
+
v04 = v0 (unidirectional Mamba) + GPA (Gated Pooling Attention in Block)
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, config: MambaMia2Config, layer_idx: int):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.num_heads = config.num_heads
|
| 183 |
+
self.hidden_size = config.hidden_size
|
| 184 |
+
self.ssm_state_size = config.state_size
|
| 185 |
+
self.conv_kernel_size = config.conv_kernel
|
| 186 |
+
self.intermediate_size = int(config.expand * self.hidden_size)
|
| 187 |
+
self.time_step_rank = int(config.time_step_rank)
|
| 188 |
+
self.layer_idx = layer_idx
|
| 189 |
+
self.use_conv_bias = config.use_conv_bias
|
| 190 |
+
self.activation = config.hidden_act
|
| 191 |
+
self.act = ACT2FN[config.hidden_act]
|
| 192 |
+
|
| 193 |
+
self.norm_before_gate = config.norm_before_gate
|
| 194 |
+
self.layer_norm_epsilon = config.layer_norm_epsilon
|
| 195 |
+
self.rms_norm = config.rms_norm
|
| 196 |
+
|
| 197 |
+
self.n_groups = config.n_groups
|
| 198 |
+
self.head_dim = config.head_dim
|
| 199 |
+
self.chunk_size = config.chunk_size
|
| 200 |
+
|
| 201 |
+
self.time_step_limit = config.time_step_limit
|
| 202 |
+
self.time_step_min = config.time_step_min
|
| 203 |
+
self.time_step_max = config.time_step_max
|
| 204 |
+
|
| 205 |
+
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
|
| 206 |
+
|
| 207 |
+
# Conv1d for SSM
|
| 208 |
+
self.conv1d = nn.Conv1d(
|
| 209 |
+
in_channels=self.conv_dim,
|
| 210 |
+
out_channels=self.conv_dim,
|
| 211 |
+
bias=config.use_conv_bias,
|
| 212 |
+
kernel_size=config.conv_kernel,
|
| 213 |
+
groups=self.conv_dim,
|
| 214 |
+
padding=config.conv_kernel - 1,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# projection of the input hidden states
|
| 218 |
+
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
|
| 219 |
+
self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=config.use_bias)
|
| 220 |
+
|
| 221 |
+
# time step projection
|
| 222 |
+
self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
|
| 223 |
+
|
| 224 |
+
# S4D real initialization
|
| 225 |
+
A = torch.arange(1, self.num_heads + 1)
|
| 226 |
+
self.A_log = nn.Parameter(torch.log(A))
|
| 227 |
+
self.A_log._no_weight_decay = True
|
| 228 |
+
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
|
| 229 |
+
self.D = nn.Parameter(torch.ones(self.num_heads))
|
| 230 |
+
self.D._no_weight_decay = True
|
| 231 |
+
|
| 232 |
+
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
| 233 |
+
self.use_bias = config.use_bias
|
| 234 |
+
|
| 235 |
+
if not is_fast_path_available:
|
| 236 |
+
logger.warning_once(
|
| 237 |
+
"The fast path is not available because one of "
|
| 238 |
+
"`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. "
|
| 239 |
+
"Falling back to the naive implementation. To install follow "
|
| 240 |
+
"https://github.com/state-spaces/mamba/#installation and "
|
| 241 |
+
"https://github.com/Dao-AILab/causal-conv1d"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
def forward(
|
| 245 |
+
self,
|
| 246 |
+
hidden_states: torch.Tensor,
|
| 247 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 248 |
+
):
|
| 249 |
+
"""
|
| 250 |
+
v04 unidirectional forward pass using CUDA kernels.
|
| 251 |
+
"""
|
| 252 |
+
import os
|
| 253 |
+
rank = int(os.environ.get("RANK", -1))
|
| 254 |
+
debug = False # (rank <= 0)
|
| 255 |
+
|
| 256 |
+
assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, \
|
| 257 |
+
"CUDA kernels required for MambaMia2Mixer"
|
| 258 |
+
|
| 259 |
+
dtype = hidden_states.dtype
|
| 260 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 261 |
+
|
| 262 |
+
if debug:
|
| 263 |
+
print(f"[Mixer DEBUG] input: min={hidden_states.min().item():.6f}, max={hidden_states.max().item():.6f}, nan={torch.isnan(hidden_states).any().item()}, seq_len={seq_len}, chunk_size={self.chunk_size}")
|
| 264 |
+
|
| 265 |
+
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
| 266 |
+
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
|
| 267 |
+
|
| 268 |
+
# Gated MLP's linear projection
|
| 269 |
+
projected_states = self.in_proj(hidden_states)
|
| 270 |
+
|
| 271 |
+
if debug:
|
| 272 |
+
print(f"[Mixer DEBUG] after in_proj: min={projected_states.min().item():.6f}, max={projected_states.max().item():.6f}, nan={torch.isnan(projected_states).any().item()}")
|
| 273 |
+
print(f"[Mixer DEBUG] A_log: {self.A_log[:5].tolist()}, dt_bias: {self.dt_bias[:5].tolist()}, D: {self.D[:5].tolist()}")
|
| 274 |
+
|
| 275 |
+
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
|
| 276 |
+
|
| 277 |
+
# Unidirectional forward pass (same as v0)
|
| 278 |
+
outputs = mamba_split_conv1d_scan_combined(
|
| 279 |
+
projected_states,
|
| 280 |
+
self.conv1d.weight.squeeze(1),
|
| 281 |
+
self.conv1d.bias,
|
| 282 |
+
self.dt_bias,
|
| 283 |
+
-torch.exp(self.A_log.float()),
|
| 284 |
+
D=self.D,
|
| 285 |
+
chunk_size=self.chunk_size,
|
| 286 |
+
seq_idx=None,
|
| 287 |
+
activation=self.activation,
|
| 288 |
+
rmsnorm_weight=self.norm.weight,
|
| 289 |
+
rmsnorm_eps=self.norm.variance_epsilon,
|
| 290 |
+
outproj_weight=self.out_proj.weight,
|
| 291 |
+
outproj_bias=self.out_proj.bias,
|
| 292 |
+
headdim=self.head_dim,
|
| 293 |
+
ngroups=self.n_groups,
|
| 294 |
+
norm_before_gate=self.norm_before_gate,
|
| 295 |
+
return_final_states=False,
|
| 296 |
+
**dt_limit_kwargs,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
if debug:
|
| 300 |
+
print(f"[Mixer DEBUG] after mamba_kernel: min={outputs.min().item():.6f}, max={outputs.max().item():.6f}, nan={torch.isnan(outputs).any().item()}")
|
| 301 |
+
|
| 302 |
+
return outputs.to(dtype)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# ============================================================================
|
| 306 |
+
# MambaMia2Block (v04 version only)
|
| 307 |
+
# ============================================================================
|
| 308 |
+
class MambaMia2Block(nn.Module):
|
| 309 |
+
"""
|
| 310 |
+
Single MambaMia2 block with v04 gated pooling attention mechanism.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
def __init__(self, config: MambaMia2Config, layer_idx: int):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.config = config
|
| 316 |
+
self.layer_idx = layer_idx
|
| 317 |
+
self.residual_in_fp32 = config.residual_in_fp32
|
| 318 |
+
self.mambamia_chunk_size = config.mambamia_chunk_size
|
| 319 |
+
|
| 320 |
+
self.norm = MambaMia2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 321 |
+
self.mixer = MambaMia2Mixer(config, layer_idx=layer_idx)
|
| 322 |
+
|
| 323 |
+
# v04 specific: Gated Pooling Attention (GPA)
|
| 324 |
+
self.drop = nn.Dropout(p=0.1)
|
| 325 |
+
|
| 326 |
+
# Per-frame weight prediction
|
| 327 |
+
self.weight_fc = nn.Linear(config.hidden_size, self.mambamia_chunk_size)
|
| 328 |
+
nn.init.zeros_(self.weight_fc.bias)
|
| 329 |
+
with torch.no_grad():
|
| 330 |
+
self.weight_fc.weight.mul_(1e-3)
|
| 331 |
+
|
| 332 |
+
# Query vs aggregator gating
|
| 333 |
+
self.gate_fc = nn.Linear(config.hidden_size, 1)
|
| 334 |
+
nn.init.zeros_(self.gate_fc.bias)
|
| 335 |
+
with torch.no_grad():
|
| 336 |
+
self.gate_fc.weight.mul_(1e-3)
|
| 337 |
+
|
| 338 |
+
def forward(
|
| 339 |
+
self,
|
| 340 |
+
hidden_states: torch.Tensor,
|
| 341 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 342 |
+
):
|
| 343 |
+
input_dtype = hidden_states.dtype
|
| 344 |
+
residual = hidden_states
|
| 345 |
+
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
|
| 346 |
+
if self.residual_in_fp32:
|
| 347 |
+
residual = residual.to(torch.float32)
|
| 348 |
+
|
| 349 |
+
# v04 Gated Pooling Attention
|
| 350 |
+
assert hidden_states.dim() == 3, f"hidden_states.dim()={hidden_states.dim()} != 3"
|
| 351 |
+
bsz, seq_len, hidden_dim = hidden_states.shape
|
| 352 |
+
mambamia_chunk_size = self.mambamia_chunk_size
|
| 353 |
+
chunk_with_query = mambamia_chunk_size + 1
|
| 354 |
+
|
| 355 |
+
if seq_len % chunk_with_query != 0:
|
| 356 |
+
raise ValueError(
|
| 357 |
+
f"seq_len={seq_len} must be divisible by (mambamia_chunk_size+1)={chunk_with_query}"
|
| 358 |
+
)
|
| 359 |
+
n_chunk = seq_len // chunk_with_query
|
| 360 |
+
|
| 361 |
+
# Reshape to (bsz, n_chunk, chunk_size+1, hidden_dim)
|
| 362 |
+
hidden_4d = hidden_states.view(bsz, n_chunk, chunk_with_query, hidden_dim)
|
| 363 |
+
|
| 364 |
+
frames = hidden_4d[:, :, :mambamia_chunk_size, :] # (bsz, n_chunk, chunk_size, hidden_dim)
|
| 365 |
+
queries = hidden_4d[:, :, mambamia_chunk_size, :] # (bsz, n_chunk, hidden_dim)
|
| 366 |
+
|
| 367 |
+
# Weight prediction for frames (float32로 계산하여 안정성 확보)
|
| 368 |
+
w_in = self.drop(queries)
|
| 369 |
+
raw_weights = self.weight_fc(w_in)
|
| 370 |
+
alpha = torch.softmax(raw_weights.float(), dim=-1).to(input_dtype) # (bsz, n_chunk, chunk_size)
|
| 371 |
+
|
| 372 |
+
# Weighted average: aggregator
|
| 373 |
+
aggregator = (frames * alpha.unsqueeze(-1)).sum(dim=2) # (bsz, n_chunk, hidden_dim)
|
| 374 |
+
|
| 375 |
+
# Gating between queries and aggregator (float32로 계산)
|
| 376 |
+
gating_in = self.drop(queries)
|
| 377 |
+
gating = torch.sigmoid(self.gate_fc(gating_in).float()).to(input_dtype) # (bsz, n_chunk, 1)
|
| 378 |
+
epsilon = 0.01
|
| 379 |
+
gating = gating * (1 - 2 * epsilon) + epsilon # [0.01, 0.99]
|
| 380 |
+
|
| 381 |
+
gating_broad = gating.expand(-1, -1, hidden_dim)
|
| 382 |
+
aggregator = aggregator * gating_broad
|
| 383 |
+
queries = queries * (1 - gating_broad)
|
| 384 |
+
queries_new = queries + aggregator
|
| 385 |
+
|
| 386 |
+
# Update query positions
|
| 387 |
+
hidden_4d = hidden_4d.clone()
|
| 388 |
+
hidden_4d[:, :, mambamia_chunk_size, :] = queries_new
|
| 389 |
+
hidden_states = hidden_4d.view(bsz, seq_len, hidden_dim)
|
| 390 |
+
|
| 391 |
+
# Mixer forward
|
| 392 |
+
hidden_states = self.mixer(hidden_states, attention_mask=attention_mask)
|
| 393 |
+
|
| 394 |
+
# Residual connection
|
| 395 |
+
hidden_states = hidden_states + residual
|
| 396 |
+
|
| 397 |
+
return hidden_states
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# ============================================================================
|
| 401 |
+
# MambaMia2Model (Simplified)
|
| 402 |
+
# ============================================================================
|
| 403 |
+
@dataclass
|
| 404 |
+
class MambaMia2Output(ModelOutput):
|
| 405 |
+
"""Output class for MambaMia2Model."""
|
| 406 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 407 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class MambaMia2PreTrainedModel(PreTrainedModel):
|
| 411 |
+
"""Base class for MambaMia2 models."""
|
| 412 |
+
config_class = MambaMia2Config
|
| 413 |
+
base_model_prefix = "backbone"
|
| 414 |
+
_no_split_modules = ["MambaMia2Block"]
|
| 415 |
+
supports_gradient_checkpointing = True
|
| 416 |
+
|
| 417 |
+
def _init_weights(self, module):
|
| 418 |
+
if isinstance(module, MambaMia2Mixer):
|
| 419 |
+
module.A_log._no_weight_decay = True
|
| 420 |
+
module.D._no_weight_decay = True
|
| 421 |
+
|
| 422 |
+
dt = torch.exp(
|
| 423 |
+
torch.rand(self.config.num_heads)
|
| 424 |
+
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
|
| 425 |
+
+ math.log(self.config.time_step_min)
|
| 426 |
+
).clamp(min=self.config.time_step_floor)
|
| 427 |
+
|
| 428 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 429 |
+
with torch.no_grad():
|
| 430 |
+
module.dt_bias.copy_(inv_dt)
|
| 431 |
+
module.dt_bias._no_reinit = True
|
| 432 |
+
|
| 433 |
+
if isinstance(module, nn.Linear):
|
| 434 |
+
if module.bias is not None:
|
| 435 |
+
if not getattr(module.bias, "_no_reinit", False):
|
| 436 |
+
nn.init.zeros_(module.bias)
|
| 437 |
+
elif isinstance(module, nn.Embedding):
|
| 438 |
+
nn.init.normal_(module.weight, std=self.config.initializer_range)
|
| 439 |
+
|
| 440 |
+
if self.config.rescale_prenorm_residual:
|
| 441 |
+
for name, p in module.named_parameters():
|
| 442 |
+
if name in ["out_proj.weight"]:
|
| 443 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
| 444 |
+
with torch.no_grad():
|
| 445 |
+
p /= math.sqrt(self.config.num_hidden_layers)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class MambaMia2Model(MambaMia2PreTrainedModel):
|
| 449 |
+
"""
|
| 450 |
+
Simplified MambaMia2 Model for v04 version.
|
| 451 |
+
Takes inputs_embeds directly (no embedding layer used for audio/video).
|
| 452 |
+
"""
|
| 453 |
+
|
| 454 |
+
def __init__(self, config: MambaMia2Config):
|
| 455 |
+
super().__init__(config)
|
| 456 |
+
self.layers = nn.ModuleList([
|
| 457 |
+
MambaMia2Block(config, layer_idx=idx)
|
| 458 |
+
for idx in range(config.num_hidden_layers)
|
| 459 |
+
])
|
| 460 |
+
self.gradient_checkpointing = False
|
| 461 |
+
self.norm_f = MambaMia2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
| 462 |
+
self.post_init()
|
| 463 |
+
|
| 464 |
+
def forward(
|
| 465 |
+
self,
|
| 466 |
+
inputs_embeds: torch.Tensor,
|
| 467 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 468 |
+
output_hidden_states: Optional[bool] = None,
|
| 469 |
+
return_dict: Optional[bool] = None,
|
| 470 |
+
) -> Union[Tuple, MambaMia2Output]:
|
| 471 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 472 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 473 |
+
|
| 474 |
+
hidden_states = inputs_embeds
|
| 475 |
+
all_hidden_states = () if output_hidden_states else None
|
| 476 |
+
|
| 477 |
+
for mixer_block in self.layers:
|
| 478 |
+
if self.gradient_checkpointing and self.training:
|
| 479 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 480 |
+
mixer_block.__call__, hidden_states, attention_mask
|
| 481 |
+
)
|
| 482 |
+
else:
|
| 483 |
+
hidden_states = mixer_block(hidden_states, attention_mask=attention_mask)
|
| 484 |
+
|
| 485 |
+
if output_hidden_states:
|
| 486 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 487 |
+
|
| 488 |
+
hidden_states = self.norm_f(hidden_states)
|
| 489 |
+
|
| 490 |
+
if output_hidden_states:
|
| 491 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 492 |
+
|
| 493 |
+
if not return_dict:
|
| 494 |
+
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
|
| 495 |
+
|
| 496 |
+
return MambaMia2Output(
|
| 497 |
+
last_hidden_state=hidden_states,
|
| 498 |
+
hidden_states=all_hidden_states,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# ============================================================================
|
| 503 |
+
# MambaMiaVideoAudioCompressorConfig
|
| 504 |
+
# ============================================================================
|
| 505 |
+
class MambaMiaVideoAudioCompressorConfig(PretrainedConfig):
|
| 506 |
+
"""
|
| 507 |
+
Configuration for MambaMiaVideoAudioCompressor.
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
input_size: Input embedding dimension (e.g., 1280 for Whisper)
|
| 511 |
+
output_size: Output embedding dimension (e.g., 2048 for LLM)
|
| 512 |
+
chunk_size: Number of tokens per chunk (default: 25, i.e., 1 second at 25Hz)
|
| 513 |
+
num_hidden_layers: Number of MambaMia2 layers (default: 1)
|
| 514 |
+
hidden_size: Internal hidden size (default: 3072, must be divisible by 24)
|
| 515 |
+
"""
|
| 516 |
+
model_type = "mambamia_videoaudio_compressor"
|
| 517 |
+
|
| 518 |
+
def __init__(
|
| 519 |
+
self,
|
| 520 |
+
input_size: int = 1280,
|
| 521 |
+
output_size: int = 2048,
|
| 522 |
+
chunk_size: int = 25,
|
| 523 |
+
num_hidden_layers: int = 1,
|
| 524 |
+
hidden_size: int = 3072,
|
| 525 |
+
**kwargs,
|
| 526 |
+
):
|
| 527 |
+
super().__init__(**kwargs)
|
| 528 |
+
self.input_size = input_size
|
| 529 |
+
self.output_size = output_size
|
| 530 |
+
self.chunk_size = chunk_size
|
| 531 |
+
self.num_hidden_layers = num_hidden_layers
|
| 532 |
+
self.hidden_size = hidden_size
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# ============================================================================
|
| 536 |
+
# MambaMiaVideoAudioCompressor - Main Interface (PreTrainedModel 기반)
|
| 537 |
+
# ============================================================================
|
| 538 |
+
class MambaMiaVideoAudioCompressor(PreTrainedModel):
|
| 539 |
+
"""
|
| 540 |
+
Video/Audio Compressor using MambaMia2 (v04 bidirectional version).
|
| 541 |
+
|
| 542 |
+
This module compresses sequential embeddings (e.g., audio frames at 25Hz)
|
| 543 |
+
by inserting learnable query tokens and extracting them after processing.
|
| 544 |
+
|
| 545 |
+
Args:
|
| 546 |
+
config: MambaMiaVideoAudioCompressorConfig
|
| 547 |
+
|
| 548 |
+
Input:
|
| 549 |
+
inputs_embeds: (batch_size, num_frames, hidden_dim) where num_frames is
|
| 550 |
+
typically the audio length and hidden_dim matches input_size
|
| 551 |
+
|
| 552 |
+
Output:
|
| 553 |
+
compressed_embeds: (batch_size, num_queries, output_size) where
|
| 554 |
+
num_queries = num_frames // chunk_size
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
config_class = MambaMiaVideoAudioCompressorConfig
|
| 558 |
+
base_model_prefix = "mambamia_compressor"
|
| 559 |
+
_no_split_modules = ["MambaMia2Block"]
|
| 560 |
+
|
| 561 |
+
def __init__(self, config: MambaMiaVideoAudioCompressorConfig):
|
| 562 |
+
super().__init__(config)
|
| 563 |
+
|
| 564 |
+
self.input_size = config.input_size
|
| 565 |
+
self.output_size = config.output_size
|
| 566 |
+
self.chunk_size = config.chunk_size
|
| 567 |
+
self.hidden_size = config.hidden_size
|
| 568 |
+
|
| 569 |
+
# Input projection: input_size -> hidden_size
|
| 570 |
+
self.input_proj = nn.Linear(config.input_size, config.hidden_size)
|
| 571 |
+
|
| 572 |
+
# Learnable query token
|
| 573 |
+
self.query_token = nn.Parameter(torch.randn(config.hidden_size))
|
| 574 |
+
|
| 575 |
+
# MambaMia2 backbone
|
| 576 |
+
# 중요: chunk_size는 SSM kernel의 chunk size로, 시퀀스 길이보다 작아야 함
|
| 577 |
+
# mambamia_chunk_size는 압축 비율 (25:1)
|
| 578 |
+
# 시퀀스 길이가 짧을 수 있으므로 (예: 390 tokens) chunk_size=64로 설정
|
| 579 |
+
mamba_config = MambaMia2Config(
|
| 580 |
+
vocab_size=0,
|
| 581 |
+
hidden_size=config.hidden_size,
|
| 582 |
+
num_hidden_layers=config.num_hidden_layers,
|
| 583 |
+
head_dim=64,
|
| 584 |
+
num_heads=config.hidden_size * 2 // 64, # e.g., 3072*2/64 = 96
|
| 585 |
+
n_groups=1,
|
| 586 |
+
expand=2.0,
|
| 587 |
+
use_cache=False,
|
| 588 |
+
chunk_size=256, # SSM kernel chunk size
|
| 589 |
+
mambamia_chunk_size=config.chunk_size, # 압축 비율용 (25)
|
| 590 |
+
residual_in_fp32=False,
|
| 591 |
+
)
|
| 592 |
+
self.model = MambaMia2Model(mamba_config)
|
| 593 |
+
|
| 594 |
+
# LayerNorm before Mamba2 to normalize input scales
|
| 595 |
+
# This ensures query_token and input_proj outputs are on the same scale
|
| 596 |
+
self.input_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
|
| 597 |
+
|
| 598 |
+
# Output projection: hidden_size -> output_size
|
| 599 |
+
self.output_proj = nn.Linear(config.hidden_size, config.output_size)
|
| 600 |
+
|
| 601 |
+
# Initialize weights (transformers style)
|
| 602 |
+
self.post_init()
|
| 603 |
+
|
| 604 |
+
def _init_weights(self, module):
|
| 605 |
+
"""
|
| 606 |
+
Initialize weights - called by post_init() for all submodules.
|
| 607 |
+
주의: MambaMia2Model 내부의 가중치는 건드리지 않음 (자체 post_init에서 처리됨)
|
| 608 |
+
"""
|
| 609 |
+
# query_token 초기화 - std=1.0으로 input_proj 출력 스케일과 맞춤
|
| 610 |
+
# (작은 std는 LayerNorm에서 variance가 0에 가까워져 inf 발생)
|
| 611 |
+
if module is self:
|
| 612 |
+
with torch.no_grad():
|
| 613 |
+
self.query_token.data.normal_(mean=0.0, std=1.0)
|
| 614 |
+
|
| 615 |
+
# input_proj, output_proj만 xavier 초기화 (MambaMia2 내부는 건드리지 않음)
|
| 616 |
+
if module is self.input_proj or module is self.output_proj:
|
| 617 |
+
nn.init.xavier_uniform_(module.weight)
|
| 618 |
+
if module.bias is not None:
|
| 619 |
+
nn.init.zeros_(module.bias)
|
| 620 |
+
|
| 621 |
+
def _init_all_weights(self):
|
| 622 |
+
"""
|
| 623 |
+
Force re-initialize all weights. Call after dtype conversion for FSDP compatibility.
|
| 624 |
+
This ensures weights are properly initialized even after model transformations.
|
| 625 |
+
"""
|
| 626 |
+
# 1. input_proj, output_proj 초기화
|
| 627 |
+
nn.init.xavier_uniform_(self.input_proj.weight)
|
| 628 |
+
if self.input_proj.bias is not None:
|
| 629 |
+
nn.init.zeros_(self.input_proj.bias)
|
| 630 |
+
nn.init.xavier_uniform_(self.output_proj.weight)
|
| 631 |
+
if self.output_proj.bias is not None:
|
| 632 |
+
nn.init.zeros_(self.output_proj.bias)
|
| 633 |
+
|
| 634 |
+
# 2. query_token 초기화 - std=1.0으로 input_proj 출력 스케일과 맞춤
|
| 635 |
+
self.query_token.data.normal_(mean=0.0, std=1.0)
|
| 636 |
+
|
| 637 |
+
# 3. input_norm (LayerNorm) 초기화
|
| 638 |
+
nn.init.ones_(self.input_norm.weight)
|
| 639 |
+
nn.init.zeros_(self.input_norm.bias)
|
| 640 |
+
|
| 641 |
+
# 4. MambaMia2Model 내부 초기화 (중요!)
|
| 642 |
+
for name, module in self.model.named_modules():
|
| 643 |
+
if isinstance(module, nn.Linear):
|
| 644 |
+
nn.init.xavier_uniform_(module.weight)
|
| 645 |
+
if module.bias is not None:
|
| 646 |
+
nn.init.zeros_(module.bias)
|
| 647 |
+
elif isinstance(module, nn.Conv1d):
|
| 648 |
+
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
| 649 |
+
if module.bias is not None:
|
| 650 |
+
nn.init.zeros_(module.bias)
|
| 651 |
+
|
| 652 |
+
# 5. MambaMia2Block의 특수 초기화 (weight_fc, gate_fc)
|
| 653 |
+
for layer in self.model.layers:
|
| 654 |
+
if hasattr(layer, 'weight_fc'):
|
| 655 |
+
nn.init.xavier_uniform_(layer.weight_fc.weight)
|
| 656 |
+
layer.weight_fc.weight.data.mul_(0.01) # Scale down
|
| 657 |
+
nn.init.zeros_(layer.weight_fc.bias)
|
| 658 |
+
if hasattr(layer, 'gate_fc'):
|
| 659 |
+
nn.init.xavier_uniform_(layer.gate_fc.weight)
|
| 660 |
+
layer.gate_fc.weight.data.mul_(0.01) # Scale down
|
| 661 |
+
nn.init.zeros_(layer.gate_fc.bias)
|
| 662 |
+
|
| 663 |
+
# 6. A_log, D, dt_bias 파라미터 초기화 (SSM specific)
|
| 664 |
+
for layer in self.model.layers:
|
| 665 |
+
if hasattr(layer, 'mixer'):
|
| 666 |
+
mixer = layer.mixer
|
| 667 |
+
# A_log: S4D real initialization
|
| 668 |
+
A = torch.arange(1, mixer.num_heads + 1, dtype=mixer.A_log.dtype, device=mixer.A_log.device)
|
| 669 |
+
mixer.A_log.data.copy_(torch.log(A))
|
| 670 |
+
# D: scaling factor
|
| 671 |
+
mixer.D.data.fill_(1.0)
|
| 672 |
+
# dt_bias: time step bias (중요!)
|
| 673 |
+
mixer.dt_bias.data.fill_(1.0)
|
| 674 |
+
|
| 675 |
+
# 7. RMSNorm weight 초기화 (MambaRMSNormGated)
|
| 676 |
+
for layer in self.model.layers:
|
| 677 |
+
if hasattr(layer, 'mixer') and hasattr(layer.mixer, 'norm'):
|
| 678 |
+
layer.mixer.norm.weight.data.fill_(1.0)
|
| 679 |
+
if hasattr(layer, 'norm') and hasattr(layer.norm, 'weight'):
|
| 680 |
+
layer.norm.weight.data.fill_(1.0)
|
| 681 |
+
|
| 682 |
+
# 8. MambaMia2Model의 최종 norm_f 초기화
|
| 683 |
+
if hasattr(self.model, 'norm_f') and hasattr(self.model.norm_f, 'weight'):
|
| 684 |
+
self.model.norm_f.weight.data.fill_(1.0)
|
| 685 |
+
|
| 686 |
+
def forward(
|
| 687 |
+
self,
|
| 688 |
+
inputs_embeds: torch.Tensor,
|
| 689 |
+
) -> torch.Tensor:
|
| 690 |
+
"""
|
| 691 |
+
Forward pass.
|
| 692 |
+
|
| 693 |
+
Args:
|
| 694 |
+
inputs_embeds: (batch_size, seq_len, input_size) or
|
| 695 |
+
(batch_size, num_frames, chunk_size, input_size)
|
| 696 |
+
|
| 697 |
+
Returns:
|
| 698 |
+
compressed: (batch_size, num_queries, output_size)
|
| 699 |
+
"""
|
| 700 |
+
import os
|
| 701 |
+
rank = int(os.environ.get("RANK", -1))
|
| 702 |
+
debug = False # True if (rank <= 0) else False
|
| 703 |
+
|
| 704 |
+
# Handle different input shapes
|
| 705 |
+
if inputs_embeds.dim() == 4:
|
| 706 |
+
# (batch_size, num_frames, chunk_size, input_size)
|
| 707 |
+
bsz, num_frames, chunk_size, _ = inputs_embeds.shape
|
| 708 |
+
assert chunk_size == self.chunk_size, \
|
| 709 |
+
f"Input chunk_size {chunk_size} != expected {self.chunk_size}"
|
| 710 |
+
inputs_embeds = inputs_embeds.view(bsz, -1, self.input_size)
|
| 711 |
+
|
| 712 |
+
bsz, seq_len, _ = inputs_embeds.shape
|
| 713 |
+
|
| 714 |
+
# Ensure seq_len is divisible by chunk_size
|
| 715 |
+
if seq_len % self.chunk_size != 0:
|
| 716 |
+
# Pad to make divisible
|
| 717 |
+
pad_len = self.chunk_size - (seq_len % self.chunk_size)
|
| 718 |
+
inputs_embeds = F.pad(inputs_embeds, (0, 0, 0, pad_len))
|
| 719 |
+
seq_len = inputs_embeds.shape[1]
|
| 720 |
+
|
| 721 |
+
n_chunk = seq_len // self.chunk_size
|
| 722 |
+
|
| 723 |
+
# Project input
|
| 724 |
+
hidden_states = self.input_proj(inputs_embeds) # (bsz, seq_len, hidden_size)
|
| 725 |
+
|
| 726 |
+
if debug:
|
| 727 |
+
print(f"[MambaMia DEBUG] input_proj output: min={hidden_states.min().item():.6f}, max={hidden_states.max().item():.6f}, has_nan={torch.isnan(hidden_states).any().item()}")
|
| 728 |
+
|
| 729 |
+
# Reshape to (bsz, n_chunk, chunk_size, hidden_size)
|
| 730 |
+
hidden_4d = hidden_states.view(bsz, n_chunk, self.chunk_size, self.hidden_size)
|
| 731 |
+
|
| 732 |
+
# Add query token to each chunk
|
| 733 |
+
# query_token: (hidden_size,) -> (1, 1, 1, hidden_size)
|
| 734 |
+
query_expanded = self.query_token.view(1, 1, 1, -1).expand(bsz, n_chunk, 1, self.hidden_size)
|
| 735 |
+
|
| 736 |
+
if debug:
|
| 737 |
+
print(f"[MambaMia DEBUG] query_token: min={self.query_token.min().item():.6f}, max={self.query_token.max().item():.6f}, has_nan={torch.isnan(self.query_token).any().item()}")
|
| 738 |
+
|
| 739 |
+
# Concatenate: (bsz, n_chunk, chunk_size+1, hidden_size)
|
| 740 |
+
hidden_with_query = torch.cat([hidden_4d, query_expanded], dim=2)
|
| 741 |
+
|
| 742 |
+
# Flatten for model: (bsz, n_chunk * (chunk_size+1), hidden_size)
|
| 743 |
+
model_input = hidden_with_query.view(bsz, -1, self.hidden_size)
|
| 744 |
+
|
| 745 |
+
# Apply LayerNorm to normalize input scales before Mamba2
|
| 746 |
+
model_input = self.input_norm(model_input)
|
| 747 |
+
|
| 748 |
+
if debug:
|
| 749 |
+
print(f"[MambaMia DEBUG] model_input (after LayerNorm, before Mamba2): min={model_input.min().item():.6f}, max={model_input.max().item():.6f}, has_nan={torch.isnan(model_input).any().item()}")
|
| 750 |
+
|
| 751 |
+
# Forward through MambaMia2
|
| 752 |
+
outputs = self.model(inputs_embeds=model_input)
|
| 753 |
+
hidden_states = outputs.last_hidden_state # (bsz, n_chunk * (chunk_size+1), hidden_size)
|
| 754 |
+
|
| 755 |
+
if debug:
|
| 756 |
+
print(f"[MambaMia DEBUG] model output (after Mamba2): min={hidden_states.min().item():.6f}, max={hidden_states.max().item():.6f}, has_nan={torch.isnan(hidden_states).any().item()}")
|
| 757 |
+
|
| 758 |
+
# Check for NaN and replace with zeros if found (defensive)
|
| 759 |
+
if torch.isnan(hidden_states).any():
|
| 760 |
+
hidden_states = torch.nan_to_num(hidden_states, nan=0.0)
|
| 761 |
+
|
| 762 |
+
# Reshape back: (bsz, n_chunk, chunk_size+1, hidden_size)
|
| 763 |
+
hidden_out_4d = hidden_states.view(bsz, n_chunk, self.chunk_size + 1, self.hidden_size)
|
| 764 |
+
|
| 765 |
+
# Extract query positions (last position in each chunk)
|
| 766 |
+
query_outputs = hidden_out_4d[:, :, self.chunk_size, :] # (bsz, n_chunk, hidden_size)
|
| 767 |
+
|
| 768 |
+
if debug:
|
| 769 |
+
print(f"[MambaMia DEBUG] query_outputs (extracted): min={query_outputs.min().item():.6f}, max={query_outputs.max().item():.6f}, has_nan={torch.isnan(query_outputs).any().item()}")
|
| 770 |
+
|
| 771 |
+
# Project to output size
|
| 772 |
+
compressed = self.output_proj(query_outputs) # (bsz, n_chunk, output_size)
|
| 773 |
+
|
| 774 |
+
if debug:
|
| 775 |
+
print(f"[MambaMia DEBUG] output_proj output: min={compressed.min().item():.6f}, max={compressed.max().item():.6f}, has_nan={torch.isnan(compressed).any().item()}")
|
| 776 |
+
|
| 777 |
+
return compressed
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
# ============================================================================
|
| 781 |
+
# Convenience function for quick instantiation
|
| 782 |
+
# ============================================================================
|
| 783 |
+
def create_mambamia_compressor(
|
| 784 |
+
input_size: int,
|
| 785 |
+
output_size: int,
|
| 786 |
+
chunk_size: int = 25,
|
| 787 |
+
num_hidden_layers: int = 2,
|
| 788 |
+
hidden_size: int = 3072,
|
| 789 |
+
) -> MambaMiaVideoAudioCompressor:
|
| 790 |
+
"""
|
| 791 |
+
Create a MambaMiaVideoAudioCompressor with default settings.
|
| 792 |
+
|
| 793 |
+
Example:
|
| 794 |
+
compressor = create_mambamia_compressor(1280, 2048, chunk_size=25)
|
| 795 |
+
"""
|
| 796 |
+
config = MambaMiaVideoAudioCompressorConfig(
|
| 797 |
+
input_size=input_size,
|
| 798 |
+
output_size=output_size,
|
| 799 |
+
chunk_size=chunk_size,
|
| 800 |
+
num_hidden_layers=num_hidden_layers,
|
| 801 |
+
hidden_size=hidden_size,
|
| 802 |
+
)
|
| 803 |
+
return MambaMiaVideoAudioCompressor(config)
|
model-00001-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f0c21a9149c81295ffd42490cfb2daf7c9e6dcc39f11a4e4c4b4fe4be8a9e2a
|
| 3 |
+
size 4707522584
|
model-00002-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:91399f86327e37a1b75ca14154cdc42d9994f39c08c34e8156503e89f10cc800
|
| 3 |
+
size 3454903840
|
model-00003-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f779b660d0a569700efef179dda58a3496decb3049b8c8269860f83b32bb647f
|
| 3 |
+
size 4999679056
|
model-00004-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73022915ccce7a97bb31c0cc60c7b7b52fe7cd4a6859949a1e8f3d60c51e2fb8
|
| 3 |
+
size 4832042296
|
model-00005-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:36449388c5903053f1e6cbac60140feb7b321a006e8493ed0f10480bca540d46
|
| 3 |
+
size 4832042328
|
model-00006-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3696ca372b23a411837f961c00f05a0eaf9a3035438f37525793753bd08823fa
|
| 3 |
+
size 4999848088
|
model-00007-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5481aa36ff6148bb35812b4de5bf2af12daf4e6d07a3106fd72d082088816cab
|
| 3 |
+
size 4832042352
|
model-00008-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:772db476e45a055871d22c837a70fbd2d4960dd075031f3d2b6b5a6ea0888db2
|
| 3 |
+
size 4832042352
|
model-00009-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb63af1e386f7ccf472c4d3108c134d863869ab13c4f9b945586a2994f3258d6
|
| 3 |
+
size 1744948136
|
model-00010-of-00010.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c5a2c76f5b7960665d75ed473ae43eb76a5d86e27d8820ede0b9b16bb40b68c
|
| 3 |
+
size 3731831368
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_hyperclovax.py
ADDED
|
@@ -0,0 +1,1866 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
import math
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import torch.utils.checkpoint
|
| 26 |
+
from torch import nn
|
| 27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
| 30 |
+
from transformers.generation import GenerationMixin
|
| 31 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 32 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 33 |
+
from transformers.modeling_outputs import (
|
| 34 |
+
BaseModelOutputWithPast,
|
| 35 |
+
CausalLMOutputWithPast,
|
| 36 |
+
QuestionAnsweringModelOutput,
|
| 37 |
+
SequenceClassifierOutputWithPast,
|
| 38 |
+
TokenClassifierOutput,
|
| 39 |
+
)
|
| 40 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 41 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 42 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 43 |
+
from transformers.utils import (
|
| 44 |
+
add_start_docstrings,
|
| 45 |
+
add_start_docstrings_to_model_forward,
|
| 46 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 47 |
+
is_torchdynamo_compiling,
|
| 48 |
+
logging,
|
| 49 |
+
replace_return_docstrings,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
from .configuration_hyperclovax import HyperCLOVAXConfig
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
_CONFIG_FOR_DOC = "HyperCLOVAXConfig"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 60 |
+
attention_mask: torch.Tensor,
|
| 61 |
+
sequence_length: int,
|
| 62 |
+
target_length: int,
|
| 63 |
+
dtype: torch.dtype,
|
| 64 |
+
device: torch.device,
|
| 65 |
+
min_dtype: float,
|
| 66 |
+
cache_position: torch.Tensor,
|
| 67 |
+
batch_size: int,
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 71 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
attention_mask (`torch.Tensor`):
|
| 75 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
| 76 |
+
sequence_length (`int`):
|
| 77 |
+
The sequence length being processed.
|
| 78 |
+
target_length (`int`):
|
| 79 |
+
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
| 80 |
+
dtype (`torch.dtype`):
|
| 81 |
+
The dtype to use for the 4D attention mask.
|
| 82 |
+
device (`torch.device`):
|
| 83 |
+
The device to plcae the 4D attention mask on.
|
| 84 |
+
min_dtype (`float`):
|
| 85 |
+
The minimum value representable with the dtype `dtype`.
|
| 86 |
+
cache_position (`torch.Tensor`):
|
| 87 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 88 |
+
batch_size (`torch.Tensor`):
|
| 89 |
+
Batch size.
|
| 90 |
+
"""
|
| 91 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 92 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 93 |
+
causal_mask = attention_mask
|
| 94 |
+
else:
|
| 95 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 96 |
+
if sequence_length != 1:
|
| 97 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 98 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 99 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 100 |
+
if attention_mask is not None:
|
| 101 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 102 |
+
mask_length = attention_mask.shape[-1]
|
| 103 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 104 |
+
padding_mask = padding_mask == 0
|
| 105 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
|
| 106 |
+
|
| 107 |
+
return causal_mask
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class HyperCLOVAXRMSNorm(nn.Module):
|
| 111 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 112 |
+
"""
|
| 113 |
+
HyperCLOVAXRMSNorm is equivalent to T5LayerNorm
|
| 114 |
+
"""
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 117 |
+
self.variance_epsilon = eps
|
| 118 |
+
|
| 119 |
+
def forward(self, hidden_states):
|
| 120 |
+
input_dtype = hidden_states.dtype
|
| 121 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 122 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 123 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 124 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 125 |
+
|
| 126 |
+
def extra_repr(self):
|
| 127 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
ALL_LAYERNORM_LAYERS.append(HyperCLOVAXRMSNorm)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class HyperCLOVAXRotaryEmbedding(nn.Module):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
dim=None,
|
| 137 |
+
max_position_embeddings=2048,
|
| 138 |
+
base=10000,
|
| 139 |
+
device=None,
|
| 140 |
+
scaling_factor=1.0,
|
| 141 |
+
rope_type="default",
|
| 142 |
+
config: Optional[HyperCLOVAXConfig] = None,
|
| 143 |
+
):
|
| 144 |
+
super().__init__()
|
| 145 |
+
# TODO (joao): remove the `if` below, only used for BC
|
| 146 |
+
self.rope_kwargs = {}
|
| 147 |
+
if config is None:
|
| 148 |
+
logger.warning_once(
|
| 149 |
+
"`HyperCLOVAXRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
| 150 |
+
"`config` argument. All other arguments will be removed in v4.46"
|
| 151 |
+
)
|
| 152 |
+
self.rope_kwargs = {
|
| 153 |
+
"rope_type": rope_type,
|
| 154 |
+
"factor": scaling_factor,
|
| 155 |
+
"dim": dim,
|
| 156 |
+
"base": base,
|
| 157 |
+
"max_position_embeddings": max_position_embeddings,
|
| 158 |
+
}
|
| 159 |
+
self.rope_type = rope_type
|
| 160 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 161 |
+
self.original_max_seq_len = max_position_embeddings
|
| 162 |
+
else:
|
| 163 |
+
# BC: "rope_type" was originally "type"
|
| 164 |
+
if config.rope_scaling is not None:
|
| 165 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 166 |
+
else:
|
| 167 |
+
self.rope_type = "default"
|
| 168 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 169 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 170 |
+
|
| 171 |
+
self.config = config
|
| 172 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 173 |
+
|
| 174 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
| 175 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 176 |
+
self.original_inv_freq = self.inv_freq
|
| 177 |
+
|
| 178 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
| 179 |
+
"""
|
| 180 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 181 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 182 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 183 |
+
"""
|
| 184 |
+
seq_len = torch.max(position_ids) + 1
|
| 185 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 186 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 187 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
| 188 |
+
)
|
| 189 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
| 190 |
+
self.max_seq_len_cached = seq_len
|
| 191 |
+
|
| 192 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
| 193 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 194 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
|
| 197 |
+
def forward(self, x, position_ids):
|
| 198 |
+
if "dynamic" in self.rope_type:
|
| 199 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
| 200 |
+
|
| 201 |
+
# Core RoPE block
|
| 202 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
| 203 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 204 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 205 |
+
device_type = x.device.type
|
| 206 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 207 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 208 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 209 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 210 |
+
cos = emb.cos()
|
| 211 |
+
sin = emb.sin()
|
| 212 |
+
|
| 213 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 214 |
+
cos = cos * self.attention_scaling
|
| 215 |
+
sin = sin * self.attention_scaling
|
| 216 |
+
|
| 217 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class HyperCLOVAXLinearScalingRotaryEmbedding(HyperCLOVAXRotaryEmbedding):
|
| 221 |
+
"""HyperCLOVAXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 222 |
+
|
| 223 |
+
def __init__(self, *args, **kwargs):
|
| 224 |
+
logger.warning_once(
|
| 225 |
+
"`HyperCLOVAXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
| 226 |
+
"`HyperCLOVAXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
| 227 |
+
)
|
| 228 |
+
kwargs["rope_type"] = "linear"
|
| 229 |
+
super().__init__(*args, **kwargs)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class HyperCLOVAXDynamicNTKScalingRotaryEmbedding(HyperCLOVAXRotaryEmbedding):
|
| 233 |
+
"""HyperCLOVAXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 234 |
+
|
| 235 |
+
def __init__(self, *args, **kwargs):
|
| 236 |
+
logger.warning_once(
|
| 237 |
+
"`HyperCLOVAXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
| 238 |
+
"`HyperCLOVAXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
| 239 |
+
"__init__)."
|
| 240 |
+
)
|
| 241 |
+
kwargs["rope_type"] = "dynamic"
|
| 242 |
+
super().__init__(*args, **kwargs)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def rotate_half(x):
|
| 246 |
+
"""Rotates half the hidden dims of the input."""
|
| 247 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 248 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 249 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 253 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
q (`torch.Tensor`): The query tensor.
|
| 257 |
+
k (`torch.Tensor`): The key tensor.
|
| 258 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 259 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 260 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 261 |
+
Deprecated and unused.
|
| 262 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 263 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 264 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 265 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 266 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 267 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 268 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 269 |
+
Returns:
|
| 270 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 271 |
+
"""
|
| 272 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 273 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 274 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 275 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 276 |
+
return q_embed, k_embed
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class HyperCLOVAXMLP(nn.Module):
|
| 280 |
+
def __init__(self, config):
|
| 281 |
+
super().__init__()
|
| 282 |
+
self.config = config
|
| 283 |
+
self.hidden_size = config.hidden_size
|
| 284 |
+
self.intermediate_size = config.intermediate_size
|
| 285 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 286 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
| 287 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
| 288 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 289 |
+
|
| 290 |
+
def forward(self, x):
|
| 291 |
+
if self.config.pretraining_tp > 1:
|
| 292 |
+
slice = self.intermediate_size // self.config.pretraining_tp
|
| 293 |
+
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
| 294 |
+
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
| 295 |
+
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
| 296 |
+
|
| 297 |
+
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
| 298 |
+
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
| 299 |
+
|
| 300 |
+
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
| 301 |
+
down_proj = [
|
| 302 |
+
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
| 303 |
+
]
|
| 304 |
+
down_proj = sum(down_proj)
|
| 305 |
+
else:
|
| 306 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 307 |
+
|
| 308 |
+
return down_proj
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 312 |
+
"""
|
| 313 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 314 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 315 |
+
"""
|
| 316 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 317 |
+
if n_rep == 1:
|
| 318 |
+
return hidden_states
|
| 319 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 320 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class HyperCLOVAXAttention(nn.Module):
|
| 324 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 325 |
+
|
| 326 |
+
def __init__(self, config: HyperCLOVAXConfig, layer_idx: Optional[int] = None):
|
| 327 |
+
super().__init__()
|
| 328 |
+
self.config = config
|
| 329 |
+
self.layer_idx = layer_idx
|
| 330 |
+
if layer_idx is None:
|
| 331 |
+
logger.warning_once(
|
| 332 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
| 333 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
| 334 |
+
"when creating this class."
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
self.attention_dropout = config.attention_dropout
|
| 338 |
+
self.hidden_size = config.hidden_size
|
| 339 |
+
self.num_heads = config.num_attention_heads
|
| 340 |
+
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
| 341 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 342 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 343 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 344 |
+
self.rope_theta = config.rope_theta
|
| 345 |
+
self.is_causal = True
|
| 346 |
+
|
| 347 |
+
self.scaling = config.attention_multiplier
|
| 348 |
+
|
| 349 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
| 350 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 351 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 352 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
| 353 |
+
|
| 354 |
+
# TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
|
| 355 |
+
self.rotary_emb = HyperCLOVAXRotaryEmbedding(config=self.config)
|
| 356 |
+
|
| 357 |
+
def forward(
|
| 358 |
+
self,
|
| 359 |
+
hidden_states: torch.Tensor,
|
| 360 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 361 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 362 |
+
past_key_value: Optional[Cache] = None,
|
| 363 |
+
output_attentions: bool = False,
|
| 364 |
+
use_cache: bool = False,
|
| 365 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 366 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 367 |
+
**kwargs,
|
| 368 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 369 |
+
bsz, q_len, _ = hidden_states.size()
|
| 370 |
+
|
| 371 |
+
if self.config.pretraining_tp > 1:
|
| 372 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
| 373 |
+
query_slices = self.q_proj.weight.split(
|
| 374 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
| 375 |
+
)
|
| 376 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
| 377 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
| 378 |
+
|
| 379 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
| 380 |
+
query_states = torch.cat(query_states, dim=-1)
|
| 381 |
+
|
| 382 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
| 383 |
+
key_states = torch.cat(key_states, dim=-1)
|
| 384 |
+
|
| 385 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
| 386 |
+
value_states = torch.cat(value_states, dim=-1)
|
| 387 |
+
|
| 388 |
+
else:
|
| 389 |
+
query_states = self.q_proj(hidden_states)
|
| 390 |
+
key_states = self.k_proj(hidden_states)
|
| 391 |
+
value_states = self.v_proj(hidden_states)
|
| 392 |
+
|
| 393 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 394 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 395 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 396 |
+
|
| 397 |
+
if position_embeddings is None:
|
| 398 |
+
logger.warning_once(
|
| 399 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 400 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 401 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 402 |
+
"removed and `position_embeddings` will be mandatory."
|
| 403 |
+
)
|
| 404 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 405 |
+
else:
|
| 406 |
+
cos, sin = position_embeddings
|
| 407 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 408 |
+
|
| 409 |
+
if past_key_value is not None:
|
| 410 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 411 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 412 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 413 |
+
|
| 414 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 415 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 416 |
+
# attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling / math.sqrt(self.head_dim)
|
| 417 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
| 418 |
+
|
| 419 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 420 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 421 |
+
attn_weights = attn_weights + causal_mask
|
| 422 |
+
|
| 423 |
+
# upcast attention to fp32
|
| 424 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 425 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 426 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 427 |
+
|
| 428 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 429 |
+
raise ValueError(
|
| 430 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 431 |
+
f" {attn_output.size()}"
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 435 |
+
|
| 436 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
| 437 |
+
|
| 438 |
+
if self.config.pretraining_tp > 1:
|
| 439 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
| 440 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
| 441 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
| 442 |
+
else:
|
| 443 |
+
attn_output = self.o_proj(attn_output)
|
| 444 |
+
|
| 445 |
+
if not output_attentions:
|
| 446 |
+
attn_weights = None
|
| 447 |
+
|
| 448 |
+
return attn_output, attn_weights, past_key_value
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class HyperCLOVAXFlashAttention2(HyperCLOVAXAttention):
|
| 452 |
+
"""
|
| 453 |
+
HyperCLOVAX flash attention module. This module inherits from `HyperCLOVAXAttention` as the weights of the module stays
|
| 454 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 455 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 456 |
+
"""
|
| 457 |
+
|
| 458 |
+
def __init__(self, *args, **kwargs):
|
| 459 |
+
super().__init__(*args, **kwargs)
|
| 460 |
+
|
| 461 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 462 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 463 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 464 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 465 |
+
|
| 466 |
+
def forward(
|
| 467 |
+
self,
|
| 468 |
+
hidden_states: torch.Tensor,
|
| 469 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 470 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 471 |
+
past_key_value: Optional[Cache] = None,
|
| 472 |
+
output_attentions: bool = False,
|
| 473 |
+
use_cache: bool = False,
|
| 474 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 475 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 476 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 477 |
+
if isinstance(past_key_value, StaticCache):
|
| 478 |
+
raise ValueError(
|
| 479 |
+
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
| 480 |
+
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
output_attentions = False
|
| 484 |
+
|
| 485 |
+
bsz, q_len, _ = hidden_states.size()
|
| 486 |
+
|
| 487 |
+
query_states = self.q_proj(hidden_states)
|
| 488 |
+
key_states = self.k_proj(hidden_states)
|
| 489 |
+
value_states = self.v_proj(hidden_states)
|
| 490 |
+
|
| 491 |
+
# Flash attention requires the input to have the shape
|
| 492 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 493 |
+
# therefore we just need to keep the original shape
|
| 494 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 495 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 496 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 497 |
+
|
| 498 |
+
if position_embeddings is None:
|
| 499 |
+
logger.warning_once(
|
| 500 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 501 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 502 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 503 |
+
"removed and `position_embeddings` will be mandatory."
|
| 504 |
+
)
|
| 505 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 506 |
+
else:
|
| 507 |
+
cos, sin = position_embeddings
|
| 508 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 509 |
+
|
| 510 |
+
if past_key_value is not None:
|
| 511 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 512 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 513 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 514 |
+
|
| 515 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 516 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 517 |
+
query_states = query_states.transpose(1, 2)
|
| 518 |
+
key_states = key_states.transpose(1, 2)
|
| 519 |
+
value_states = value_states.transpose(1, 2)
|
| 520 |
+
|
| 521 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 522 |
+
|
| 523 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 524 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 525 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 526 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 527 |
+
# in fp32. (HyperCLOVAXRMSNorm handles it correctly)
|
| 528 |
+
|
| 529 |
+
input_dtype = query_states.dtype
|
| 530 |
+
if input_dtype == torch.float32:
|
| 531 |
+
if torch.is_autocast_enabled():
|
| 532 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 533 |
+
# Handle the case where the model is quantized
|
| 534 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 535 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 536 |
+
else:
|
| 537 |
+
target_dtype = self.q_proj.weight.dtype
|
| 538 |
+
|
| 539 |
+
logger.warning_once(
|
| 540 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 541 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 542 |
+
f" {target_dtype}."
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
query_states = query_states.to(target_dtype)
|
| 546 |
+
key_states = key_states.to(target_dtype)
|
| 547 |
+
value_states = value_states.to(target_dtype)
|
| 548 |
+
|
| 549 |
+
attn_output = _flash_attention_forward(
|
| 550 |
+
query_states,
|
| 551 |
+
key_states,
|
| 552 |
+
value_states,
|
| 553 |
+
attention_mask,
|
| 554 |
+
q_len,
|
| 555 |
+
position_ids=position_ids,
|
| 556 |
+
dropout=dropout_rate,
|
| 557 |
+
softmax_scale=self.scaling, # mup
|
| 558 |
+
sliding_window=getattr(self, "sliding_window", None),
|
| 559 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 560 |
+
is_causal=self.is_causal,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
| 564 |
+
attn_output = self.o_proj(attn_output)
|
| 565 |
+
|
| 566 |
+
if not output_attentions:
|
| 567 |
+
attn_weights = None
|
| 568 |
+
|
| 569 |
+
return attn_output, attn_weights, past_key_value
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class HyperCLOVAXSdpaAttention(HyperCLOVAXAttention):
|
| 573 |
+
"""
|
| 574 |
+
HyperCLOVAX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 575 |
+
`HyperCLOVAXAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 576 |
+
SDPA API.
|
| 577 |
+
"""
|
| 578 |
+
|
| 579 |
+
# Adapted from HyperCLOVAXAttention.forward
|
| 580 |
+
def forward(
|
| 581 |
+
self,
|
| 582 |
+
hidden_states: torch.Tensor,
|
| 583 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 584 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 585 |
+
past_key_value: Optional[Cache] = None,
|
| 586 |
+
output_attentions: bool = False,
|
| 587 |
+
use_cache: bool = False,
|
| 588 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 589 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 590 |
+
**kwargs,
|
| 591 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 592 |
+
if output_attentions:
|
| 593 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 594 |
+
logger.warning_once(
|
| 595 |
+
"HyperCLOVAXModel is using HyperCLOVAXSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 596 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 597 |
+
)
|
| 598 |
+
return super().forward(
|
| 599 |
+
hidden_states=hidden_states,
|
| 600 |
+
attention_mask=attention_mask,
|
| 601 |
+
position_ids=position_ids,
|
| 602 |
+
past_key_value=past_key_value,
|
| 603 |
+
output_attentions=output_attentions,
|
| 604 |
+
use_cache=use_cache,
|
| 605 |
+
cache_position=cache_position,
|
| 606 |
+
position_embeddings=position_embeddings,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
bsz, q_len, _ = hidden_states.size()
|
| 610 |
+
|
| 611 |
+
query_states = self.q_proj(hidden_states)
|
| 612 |
+
key_states = self.k_proj(hidden_states)
|
| 613 |
+
value_states = self.v_proj(hidden_states)
|
| 614 |
+
|
| 615 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 616 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 617 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 618 |
+
|
| 619 |
+
if position_embeddings is None:
|
| 620 |
+
logger.warning_once(
|
| 621 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 622 |
+
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
| 623 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
| 624 |
+
"removed and `position_embeddings` will be mandatory."
|
| 625 |
+
)
|
| 626 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 627 |
+
else:
|
| 628 |
+
cos, sin = position_embeddings
|
| 629 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 630 |
+
|
| 631 |
+
if past_key_value is not None:
|
| 632 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 633 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 634 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 635 |
+
|
| 636 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 637 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 638 |
+
|
| 639 |
+
causal_mask = attention_mask
|
| 640 |
+
if attention_mask is not None:
|
| 641 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
| 642 |
+
|
| 643 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 644 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 645 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 646 |
+
query_states = query_states.contiguous()
|
| 647 |
+
key_states = key_states.contiguous()
|
| 648 |
+
value_states = value_states.contiguous()
|
| 649 |
+
|
| 650 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 651 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 652 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 653 |
+
|
| 654 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 655 |
+
query_states,
|
| 656 |
+
key_states,
|
| 657 |
+
value_states,
|
| 658 |
+
attn_mask=causal_mask,
|
| 659 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 660 |
+
is_causal=is_causal,
|
| 661 |
+
scale=self.scaling, # mup
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 665 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
| 666 |
+
|
| 667 |
+
attn_output = self.o_proj(attn_output)
|
| 668 |
+
|
| 669 |
+
return attn_output, None, past_key_value
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
HyperCLOVAX_ATTENTION_CLASSES = {
|
| 673 |
+
"eager": HyperCLOVAXAttention,
|
| 674 |
+
"flash_attention_2": HyperCLOVAXFlashAttention2,
|
| 675 |
+
"sdpa": HyperCLOVAXSdpaAttention,
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
class HyperCLOVAXDecoderLayer(nn.Module):
|
| 680 |
+
def __init__(self, config: HyperCLOVAXConfig, layer_idx: int):
|
| 681 |
+
super().__init__()
|
| 682 |
+
self.hidden_size = config.hidden_size
|
| 683 |
+
|
| 684 |
+
self.self_attn = HyperCLOVAX_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 685 |
+
|
| 686 |
+
self.mlp = HyperCLOVAXMLP(config)
|
| 687 |
+
self.input_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 688 |
+
self.post_attention_layernorm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 689 |
+
|
| 690 |
+
# post-norm (dual-norm)
|
| 691 |
+
self.use_post_norm = config.use_post_norm
|
| 692 |
+
if self.use_post_norm:
|
| 693 |
+
self.post_norm1 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 694 |
+
self.post_norm2 = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 695 |
+
|
| 696 |
+
self.residual_multiplier = config.residual_multiplier # mup
|
| 697 |
+
|
| 698 |
+
def forward(
|
| 699 |
+
self,
|
| 700 |
+
hidden_states: torch.Tensor,
|
| 701 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 702 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 703 |
+
past_key_value: Optional[Cache] = None,
|
| 704 |
+
output_attentions: Optional[bool] = False,
|
| 705 |
+
use_cache: Optional[bool] = False,
|
| 706 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 707 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 708 |
+
**kwargs,
|
| 709 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 710 |
+
"""
|
| 711 |
+
Args:
|
| 712 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 713 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
| 714 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 715 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 716 |
+
output_attentions (`bool`, *optional*):
|
| 717 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 718 |
+
returned tensors for more detail.
|
| 719 |
+
use_cache (`bool`, *optional*):
|
| 720 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 721 |
+
(see `past_key_values`).
|
| 722 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 723 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 724 |
+
Indices depicting the position of the input sequence tokens in the sequence
|
| 725 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 726 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 727 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 728 |
+
kwargs (`dict`, *optional*):
|
| 729 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
| 730 |
+
into the model
|
| 731 |
+
"""
|
| 732 |
+
residual = hidden_states
|
| 733 |
+
|
| 734 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 735 |
+
|
| 736 |
+
# Self Attention
|
| 737 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 738 |
+
hidden_states=hidden_states,
|
| 739 |
+
attention_mask=attention_mask,
|
| 740 |
+
position_ids=position_ids,
|
| 741 |
+
past_key_value=past_key_value,
|
| 742 |
+
output_attentions=output_attentions,
|
| 743 |
+
use_cache=use_cache,
|
| 744 |
+
cache_position=cache_position,
|
| 745 |
+
position_embeddings=position_embeddings,
|
| 746 |
+
**kwargs,
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
if self.use_post_norm:
|
| 750 |
+
hidden_states = self.post_norm1(hidden_states)
|
| 751 |
+
|
| 752 |
+
hidden_states = residual + hidden_states * self.residual_multiplier # mup
|
| 753 |
+
|
| 754 |
+
# Fully Connected
|
| 755 |
+
residual = hidden_states
|
| 756 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 757 |
+
hidden_states = self.mlp(hidden_states)
|
| 758 |
+
|
| 759 |
+
if self.use_post_norm:
|
| 760 |
+
hidden_states = self.post_norm2(hidden_states)
|
| 761 |
+
|
| 762 |
+
hidden_states = residual + hidden_states * self.residual_multiplier # mup
|
| 763 |
+
|
| 764 |
+
outputs = (hidden_states,)
|
| 765 |
+
|
| 766 |
+
if output_attentions:
|
| 767 |
+
outputs += (self_attn_weights,)
|
| 768 |
+
|
| 769 |
+
if use_cache:
|
| 770 |
+
outputs += (present_key_value,)
|
| 771 |
+
|
| 772 |
+
return outputs
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
HyperCLOVAX_START_DOCSTRING = r"""
|
| 776 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 777 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 778 |
+
etc.)
|
| 779 |
+
|
| 780 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 781 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 782 |
+
and behavior.
|
| 783 |
+
|
| 784 |
+
Parameters:
|
| 785 |
+
config ([`HyperCLOVAXConfig`]):
|
| 786 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 787 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 788 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 789 |
+
"""
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
@add_start_docstrings(
|
| 793 |
+
"The bare HyperCLOVAX Model outputting raw hidden-states without any specific head on top.",
|
| 794 |
+
HyperCLOVAX_START_DOCSTRING,
|
| 795 |
+
)
|
| 796 |
+
class HyperCLOVAXPreTrainedModel(PreTrainedModel):
|
| 797 |
+
config_class = HyperCLOVAXConfig
|
| 798 |
+
base_model_prefix = "model"
|
| 799 |
+
supports_gradient_checkpointing = True
|
| 800 |
+
_no_split_modules = ["HyperCLOVAXDecoderLayer"]
|
| 801 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 802 |
+
_supports_flash_attn_2 = True
|
| 803 |
+
_supports_sdpa = True
|
| 804 |
+
_supports_cache_class = True
|
| 805 |
+
_supports_quantized_cache = True
|
| 806 |
+
_supports_static_cache = True
|
| 807 |
+
|
| 808 |
+
def _init_weights(self, module):
|
| 809 |
+
std = self.config.initializer_range
|
| 810 |
+
if isinstance(module, nn.Linear):
|
| 811 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 812 |
+
if module.bias is not None:
|
| 813 |
+
module.bias.data.zero_()
|
| 814 |
+
elif isinstance(module, nn.Embedding):
|
| 815 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 816 |
+
if module.padding_idx is not None:
|
| 817 |
+
module.weight.data[module.padding_idx].zero_()
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
HyperCLOVAX_INPUTS_DOCSTRING = r"""
|
| 821 |
+
Args:
|
| 822 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 823 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 824 |
+
it.
|
| 825 |
+
|
| 826 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 827 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 828 |
+
|
| 829 |
+
[What are input IDs?](../glossary#input-ids)
|
| 830 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 831 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 832 |
+
|
| 833 |
+
- 1 for tokens that are **not masked**,
|
| 834 |
+
- 0 for tokens that are **masked**.
|
| 835 |
+
|
| 836 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 837 |
+
|
| 838 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 839 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 840 |
+
|
| 841 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 842 |
+
`past_key_values`).
|
| 843 |
+
|
| 844 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 845 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 846 |
+
information on the default strategy.
|
| 847 |
+
|
| 848 |
+
- 1 indicates the head is **not masked**,
|
| 849 |
+
- 0 indicates the head is **masked**.
|
| 850 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 851 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 852 |
+
config.n_positions - 1]`.
|
| 853 |
+
|
| 854 |
+
[What are position IDs?](../glossary#position-ids)
|
| 855 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 856 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 857 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 858 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 859 |
+
|
| 860 |
+
Two formats are allowed:
|
| 861 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 862 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 863 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 864 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 865 |
+
cache format.
|
| 866 |
+
|
| 867 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 868 |
+
legacy cache format will be returned.
|
| 869 |
+
|
| 870 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 871 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 872 |
+
of shape `(batch_size, sequence_length)`.
|
| 873 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 874 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 875 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 876 |
+
model's internal embedding lookup matrix.
|
| 877 |
+
use_cache (`bool`, *optional*):
|
| 878 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 879 |
+
`past_key_values`).
|
| 880 |
+
output_attentions (`bool`, *optional*):
|
| 881 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 882 |
+
tensors for more detail.
|
| 883 |
+
output_hidden_states (`bool`, *optional*):
|
| 884 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 885 |
+
more detail.
|
| 886 |
+
return_dict (`bool`, *optional*):
|
| 887 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 888 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 889 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 890 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 891 |
+
the complete sequence length.
|
| 892 |
+
"""
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
@add_start_docstrings(
|
| 896 |
+
"The bare HyperCLOVAX Model outputting raw hidden-states without any specific head on top.",
|
| 897 |
+
HyperCLOVAX_START_DOCSTRING,
|
| 898 |
+
)
|
| 899 |
+
class HyperCLOVAXModel(HyperCLOVAXPreTrainedModel):
|
| 900 |
+
"""
|
| 901 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HyperCLOVAXDecoderLayer`]
|
| 902 |
+
|
| 903 |
+
Args:
|
| 904 |
+
config: HyperCLOVAXConfig
|
| 905 |
+
"""
|
| 906 |
+
|
| 907 |
+
def __init__(self, config: HyperCLOVAXConfig):
|
| 908 |
+
super().__init__(config)
|
| 909 |
+
self.padding_idx = config.pad_token_id
|
| 910 |
+
self.vocab_size = config.vocab_size
|
| 911 |
+
|
| 912 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 913 |
+
self.layers = nn.ModuleList(
|
| 914 |
+
[HyperCLOVAXDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 915 |
+
)
|
| 916 |
+
self.norm = HyperCLOVAXRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 917 |
+
self.rotary_emb = HyperCLOVAXRotaryEmbedding(config=config)
|
| 918 |
+
self.gradient_checkpointing = False
|
| 919 |
+
|
| 920 |
+
# Initialize weights and apply final processing
|
| 921 |
+
self.post_init()
|
| 922 |
+
|
| 923 |
+
# mup
|
| 924 |
+
self.embedding_multiplier = config.embedding_multiplier
|
| 925 |
+
|
| 926 |
+
def get_input_embeddings(self):
|
| 927 |
+
return self.embed_tokens
|
| 928 |
+
|
| 929 |
+
def set_input_embeddings(self, value):
|
| 930 |
+
self.embed_tokens = value
|
| 931 |
+
|
| 932 |
+
@add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
|
| 933 |
+
def forward(
|
| 934 |
+
self,
|
| 935 |
+
input_ids: torch.LongTensor = None,
|
| 936 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 937 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 938 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 939 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 940 |
+
use_cache: Optional[bool] = None,
|
| 941 |
+
output_attentions: Optional[bool] = None,
|
| 942 |
+
output_hidden_states: Optional[bool] = None,
|
| 943 |
+
return_dict: Optional[bool] = None,
|
| 944 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 945 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 946 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 947 |
+
output_hidden_states = (
|
| 948 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 949 |
+
)
|
| 950 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 951 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 952 |
+
|
| 953 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 954 |
+
raise ValueError(
|
| 955 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 959 |
+
logger.warning_once(
|
| 960 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 961 |
+
)
|
| 962 |
+
use_cache = False
|
| 963 |
+
|
| 964 |
+
if inputs_embeds is None:
|
| 965 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 966 |
+
|
| 967 |
+
inputs_embeds = inputs_embeds * self.embedding_multiplier # mup
|
| 968 |
+
|
| 969 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 970 |
+
return_legacy_cache = False
|
| 971 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 972 |
+
return_legacy_cache = True
|
| 973 |
+
if past_key_values is None:
|
| 974 |
+
past_key_values = DynamicCache()
|
| 975 |
+
else:
|
| 976 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 977 |
+
logger.warning_once(
|
| 978 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 979 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 980 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
if cache_position is None:
|
| 984 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 985 |
+
cache_position = torch.arange(
|
| 986 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 987 |
+
)
|
| 988 |
+
if position_ids is None:
|
| 989 |
+
position_ids = cache_position.unsqueeze(0)
|
| 990 |
+
|
| 991 |
+
causal_mask = self._update_causal_mask(
|
| 992 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 993 |
+
)
|
| 994 |
+
hidden_states = inputs_embeds
|
| 995 |
+
|
| 996 |
+
# create position embeddings to be shared across the decoder layers
|
| 997 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 998 |
+
|
| 999 |
+
# decoder layers
|
| 1000 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1001 |
+
all_self_attns = () if output_attentions else None
|
| 1002 |
+
next_decoder_cache = None
|
| 1003 |
+
|
| 1004 |
+
for decoder_layer in self.layers:
|
| 1005 |
+
if output_hidden_states:
|
| 1006 |
+
all_hidden_states += (hidden_states,)
|
| 1007 |
+
|
| 1008 |
+
if self.gradient_checkpointing and self.training:
|
| 1009 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1010 |
+
decoder_layer.__call__,
|
| 1011 |
+
hidden_states,
|
| 1012 |
+
causal_mask,
|
| 1013 |
+
position_ids,
|
| 1014 |
+
past_key_values,
|
| 1015 |
+
output_attentions,
|
| 1016 |
+
use_cache,
|
| 1017 |
+
cache_position,
|
| 1018 |
+
position_embeddings,
|
| 1019 |
+
)
|
| 1020 |
+
else:
|
| 1021 |
+
layer_outputs = decoder_layer(
|
| 1022 |
+
hidden_states,
|
| 1023 |
+
attention_mask=causal_mask,
|
| 1024 |
+
position_ids=position_ids,
|
| 1025 |
+
past_key_value=past_key_values,
|
| 1026 |
+
output_attentions=output_attentions,
|
| 1027 |
+
use_cache=use_cache,
|
| 1028 |
+
cache_position=cache_position,
|
| 1029 |
+
position_embeddings=position_embeddings,
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
+
hidden_states = layer_outputs[0]
|
| 1033 |
+
|
| 1034 |
+
if use_cache:
|
| 1035 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 1036 |
+
|
| 1037 |
+
if output_attentions:
|
| 1038 |
+
all_self_attns += (layer_outputs[1],)
|
| 1039 |
+
|
| 1040 |
+
hidden_states = self.norm(hidden_states)
|
| 1041 |
+
|
| 1042 |
+
# add hidden states from the last decoder layer
|
| 1043 |
+
if output_hidden_states:
|
| 1044 |
+
all_hidden_states += (hidden_states,)
|
| 1045 |
+
|
| 1046 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 1047 |
+
if return_legacy_cache:
|
| 1048 |
+
next_cache = next_cache.to_legacy_cache()
|
| 1049 |
+
|
| 1050 |
+
if not return_dict:
|
| 1051 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 1052 |
+
return BaseModelOutputWithPast(
|
| 1053 |
+
last_hidden_state=hidden_states,
|
| 1054 |
+
past_key_values=next_cache,
|
| 1055 |
+
hidden_states=all_hidden_states,
|
| 1056 |
+
attentions=all_self_attns,
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
def _update_causal_mask(
|
| 1060 |
+
self,
|
| 1061 |
+
attention_mask: torch.Tensor,
|
| 1062 |
+
input_tensor: torch.Tensor,
|
| 1063 |
+
cache_position: torch.Tensor,
|
| 1064 |
+
past_key_values: Cache,
|
| 1065 |
+
output_attentions: bool,
|
| 1066 |
+
):
|
| 1067 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1068 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1069 |
+
return attention_mask
|
| 1070 |
+
return None
|
| 1071 |
+
|
| 1072 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1073 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1074 |
+
# to infer the attention mask.
|
| 1075 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1076 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1077 |
+
|
| 1078 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1079 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 1080 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1081 |
+
attention_mask,
|
| 1082 |
+
inputs_embeds=input_tensor,
|
| 1083 |
+
past_key_values_length=past_seen_tokens,
|
| 1084 |
+
is_training=self.training,
|
| 1085 |
+
):
|
| 1086 |
+
return None
|
| 1087 |
+
|
| 1088 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 1089 |
+
min_dtype = torch.finfo(dtype).min
|
| 1090 |
+
sequence_length = input_tensor.shape[1]
|
| 1091 |
+
if using_static_cache:
|
| 1092 |
+
target_length = past_key_values.get_max_length()
|
| 1093 |
+
else:
|
| 1094 |
+
target_length = (
|
| 1095 |
+
attention_mask.shape[-1]
|
| 1096 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1097 |
+
else past_seen_tokens + sequence_length + 1
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1101 |
+
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1102 |
+
attention_mask,
|
| 1103 |
+
sequence_length=sequence_length,
|
| 1104 |
+
target_length=target_length,
|
| 1105 |
+
dtype=dtype,
|
| 1106 |
+
device=device,
|
| 1107 |
+
min_dtype=min_dtype,
|
| 1108 |
+
cache_position=cache_position,
|
| 1109 |
+
batch_size=input_tensor.shape[0],
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
if (
|
| 1113 |
+
self.config._attn_implementation == "sdpa"
|
| 1114 |
+
and attention_mask is not None
|
| 1115 |
+
and attention_mask.device.type == "cuda"
|
| 1116 |
+
and not output_attentions
|
| 1117 |
+
):
|
| 1118 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1119 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1120 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1121 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 1122 |
+
|
| 1123 |
+
return causal_mask
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
class HyperCLOVAXForCausalLM(HyperCLOVAXPreTrainedModel, GenerationMixin):
|
| 1127 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1128 |
+
|
| 1129 |
+
def __init__(self, config):
|
| 1130 |
+
super().__init__(config)
|
| 1131 |
+
self.model = HyperCLOVAXModel(config)
|
| 1132 |
+
self.vocab_size = config.vocab_size
|
| 1133 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1134 |
+
|
| 1135 |
+
# Initialize weights and apply final processing
|
| 1136 |
+
self.post_init()
|
| 1137 |
+
|
| 1138 |
+
def _get_apply_liger_kernel_converter(self):
|
| 1139 |
+
return _apply_liger_kernel_to_instance
|
| 1140 |
+
|
| 1141 |
+
def get_input_embeddings(self):
|
| 1142 |
+
return self.model.embed_tokens
|
| 1143 |
+
|
| 1144 |
+
def set_input_embeddings(self, value):
|
| 1145 |
+
self.model.embed_tokens = value
|
| 1146 |
+
|
| 1147 |
+
def get_output_embeddings(self):
|
| 1148 |
+
return self.lm_head
|
| 1149 |
+
|
| 1150 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1151 |
+
self.lm_head = new_embeddings
|
| 1152 |
+
|
| 1153 |
+
def set_decoder(self, decoder):
|
| 1154 |
+
self.model = decoder
|
| 1155 |
+
|
| 1156 |
+
def get_decoder(self):
|
| 1157 |
+
return self.model
|
| 1158 |
+
|
| 1159 |
+
@add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
|
| 1160 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1161 |
+
def forward(
|
| 1162 |
+
self,
|
| 1163 |
+
input_ids: torch.LongTensor = None,
|
| 1164 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1165 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1166 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 1167 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1168 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1169 |
+
use_cache: Optional[bool] = None,
|
| 1170 |
+
output_attentions: Optional[bool] = None,
|
| 1171 |
+
output_hidden_states: Optional[bool] = None,
|
| 1172 |
+
return_dict: Optional[bool] = None,
|
| 1173 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1174 |
+
num_logits_to_keep: int = 0,
|
| 1175 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1176 |
+
r"""
|
| 1177 |
+
Args:
|
| 1178 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1179 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1180 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1181 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1182 |
+
|
| 1183 |
+
num_logits_to_keep (`int`, *optional*):
|
| 1184 |
+
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
| 1185 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 1186 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 1187 |
+
|
| 1188 |
+
Returns:
|
| 1189 |
+
|
| 1190 |
+
Example:
|
| 1191 |
+
|
| 1192 |
+
```python
|
| 1193 |
+
>>> from transformers import AutoTokenizer, HyperCLOVAXForCausalLM
|
| 1194 |
+
|
| 1195 |
+
>>> model = HyperCLOVAXForCausalLM.from_pretrained(YOUR_DIR)
|
| 1196 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(YOUR_DIR)
|
| 1197 |
+
|
| 1198 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1199 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1200 |
+
|
| 1201 |
+
>>> # Generate
|
| 1202 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1203 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1204 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1205 |
+
```"""
|
| 1206 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1207 |
+
output_hidden_states = (
|
| 1208 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1209 |
+
)
|
| 1210 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1211 |
+
|
| 1212 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1213 |
+
outputs = self.model(
|
| 1214 |
+
input_ids=input_ids,
|
| 1215 |
+
attention_mask=attention_mask,
|
| 1216 |
+
position_ids=position_ids,
|
| 1217 |
+
past_key_values=past_key_values,
|
| 1218 |
+
inputs_embeds=inputs_embeds,
|
| 1219 |
+
use_cache=use_cache,
|
| 1220 |
+
output_attentions=output_attentions,
|
| 1221 |
+
output_hidden_states=output_hidden_states,
|
| 1222 |
+
return_dict=return_dict,
|
| 1223 |
+
cache_position=cache_position,
|
| 1224 |
+
)
|
| 1225 |
+
|
| 1226 |
+
hidden_states = outputs[0]
|
| 1227 |
+
if self.config.pretraining_tp > 1:
|
| 1228 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
| 1229 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
| 1230 |
+
logits = torch.cat(logits, dim=-1)
|
| 1231 |
+
else:
|
| 1232 |
+
if labels is None and not is_torchdynamo_compiling():
|
| 1233 |
+
logger.warning_once(
|
| 1234 |
+
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
|
| 1235 |
+
)
|
| 1236 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1237 |
+
# TODO: remove the float() operation in v4.46
|
| 1238 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
|
| 1239 |
+
|
| 1240 |
+
logits = logits * self.config.logits_scaling # mup
|
| 1241 |
+
|
| 1242 |
+
loss = None
|
| 1243 |
+
if labels is not None:
|
| 1244 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 1245 |
+
logits = logits.float()
|
| 1246 |
+
# Shift so that tokens < n predict n
|
| 1247 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1248 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1249 |
+
# Flatten the tokens
|
| 1250 |
+
loss_fct = CrossEntropyLoss()
|
| 1251 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1252 |
+
shift_labels = shift_labels.view(-1)
|
| 1253 |
+
# Enable model parallelism
|
| 1254 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1255 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1256 |
+
|
| 1257 |
+
if not return_dict:
|
| 1258 |
+
output = (logits,) + outputs[1:]
|
| 1259 |
+
return (loss,) + output if loss is not None else output
|
| 1260 |
+
|
| 1261 |
+
return CausalLMOutputWithPast(
|
| 1262 |
+
loss=loss,
|
| 1263 |
+
logits=logits,
|
| 1264 |
+
past_key_values=outputs.past_key_values,
|
| 1265 |
+
hidden_states=outputs.hidden_states,
|
| 1266 |
+
attentions=outputs.attentions,
|
| 1267 |
+
)
|
| 1268 |
+
|
| 1269 |
+
def prepare_inputs_for_generation(
|
| 1270 |
+
self,
|
| 1271 |
+
input_ids,
|
| 1272 |
+
past_key_values=None,
|
| 1273 |
+
attention_mask=None,
|
| 1274 |
+
inputs_embeds=None,
|
| 1275 |
+
cache_position=None,
|
| 1276 |
+
position_ids=None,
|
| 1277 |
+
use_cache=True,
|
| 1278 |
+
num_logits_to_keep=None,
|
| 1279 |
+
**kwargs,
|
| 1280 |
+
):
|
| 1281 |
+
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
| 1282 |
+
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
| 1283 |
+
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
| 1284 |
+
if past_key_values is not None:
|
| 1285 |
+
if inputs_embeds is not None: # Exception 1
|
| 1286 |
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
| 1287 |
+
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
| 1288 |
+
input_ids = input_ids[:, cache_position]
|
| 1289 |
+
|
| 1290 |
+
if attention_mask is not None and position_ids is None:
|
| 1291 |
+
# create position_ids on the fly for batch generation
|
| 1292 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1293 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1294 |
+
if past_key_values:
|
| 1295 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 1296 |
+
|
| 1297 |
+
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
| 1298 |
+
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
| 1299 |
+
|
| 1300 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1301 |
+
if inputs_embeds is not None and cache_position[0] == 0:
|
| 1302 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
| 1303 |
+
else:
|
| 1304 |
+
# The clone here is for the same reason as for `position_ids`.
|
| 1305 |
+
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
| 1306 |
+
|
| 1307 |
+
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
| 1308 |
+
if model_inputs["inputs_embeds"] is not None:
|
| 1309 |
+
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
| 1310 |
+
device = model_inputs["inputs_embeds"].device
|
| 1311 |
+
else:
|
| 1312 |
+
batch_size, sequence_length = model_inputs["input_ids"].shape
|
| 1313 |
+
device = model_inputs["input_ids"].device
|
| 1314 |
+
|
| 1315 |
+
dtype = self.lm_head.weight.dtype
|
| 1316 |
+
min_dtype = torch.finfo(dtype).min
|
| 1317 |
+
|
| 1318 |
+
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1319 |
+
attention_mask,
|
| 1320 |
+
sequence_length=sequence_length,
|
| 1321 |
+
target_length=past_key_values.get_max_length(),
|
| 1322 |
+
dtype=dtype,
|
| 1323 |
+
device=device,
|
| 1324 |
+
min_dtype=min_dtype,
|
| 1325 |
+
cache_position=cache_position,
|
| 1326 |
+
batch_size=batch_size,
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
if num_logits_to_keep is not None:
|
| 1330 |
+
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
| 1331 |
+
|
| 1332 |
+
model_inputs.update(
|
| 1333 |
+
{
|
| 1334 |
+
"position_ids": position_ids,
|
| 1335 |
+
"cache_position": cache_position,
|
| 1336 |
+
"past_key_values": past_key_values,
|
| 1337 |
+
"use_cache": use_cache,
|
| 1338 |
+
"attention_mask": attention_mask,
|
| 1339 |
+
}
|
| 1340 |
+
)
|
| 1341 |
+
return model_inputs
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
@add_start_docstrings(
|
| 1345 |
+
"""
|
| 1346 |
+
The HyperCLOVAX Model transformer with a sequence classification head on top (linear layer).
|
| 1347 |
+
|
| 1348 |
+
[`HyperCLOVAXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
| 1349 |
+
(e.g. GPT-2) do.
|
| 1350 |
+
|
| 1351 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
| 1352 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 1353 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 1354 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 1355 |
+
each row of the batch).
|
| 1356 |
+
""",
|
| 1357 |
+
HyperCLOVAX_START_DOCSTRING,
|
| 1358 |
+
)
|
| 1359 |
+
class HyperCLOVAXForSequenceClassification(HyperCLOVAXPreTrainedModel):
|
| 1360 |
+
def __init__(self, config):
|
| 1361 |
+
super().__init__(config)
|
| 1362 |
+
self.num_labels = config.num_labels
|
| 1363 |
+
self.model = HyperCLOVAXModel(config)
|
| 1364 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1365 |
+
|
| 1366 |
+
# Initialize weights and apply final processing
|
| 1367 |
+
self.post_init()
|
| 1368 |
+
|
| 1369 |
+
def get_input_embeddings(self):
|
| 1370 |
+
return self.model.embed_tokens
|
| 1371 |
+
|
| 1372 |
+
def set_input_embeddings(self, value):
|
| 1373 |
+
self.model.embed_tokens = value
|
| 1374 |
+
|
| 1375 |
+
@add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
|
| 1376 |
+
def forward(
|
| 1377 |
+
self,
|
| 1378 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1379 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1380 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1381 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 1382 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1383 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1384 |
+
use_cache: Optional[bool] = None,
|
| 1385 |
+
output_attentions: Optional[bool] = None,
|
| 1386 |
+
output_hidden_states: Optional[bool] = None,
|
| 1387 |
+
return_dict: Optional[bool] = None,
|
| 1388 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
| 1389 |
+
r"""
|
| 1390 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1391 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1392 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1393 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1394 |
+
"""
|
| 1395 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1396 |
+
|
| 1397 |
+
transformer_outputs = self.model(
|
| 1398 |
+
input_ids,
|
| 1399 |
+
attention_mask=attention_mask,
|
| 1400 |
+
position_ids=position_ids,
|
| 1401 |
+
past_key_values=past_key_values,
|
| 1402 |
+
inputs_embeds=inputs_embeds,
|
| 1403 |
+
use_cache=use_cache,
|
| 1404 |
+
output_attentions=output_attentions,
|
| 1405 |
+
output_hidden_states=output_hidden_states,
|
| 1406 |
+
return_dict=return_dict,
|
| 1407 |
+
)
|
| 1408 |
+
hidden_states = transformer_outputs[0]
|
| 1409 |
+
logits = self.score(hidden_states)
|
| 1410 |
+
|
| 1411 |
+
if input_ids is not None:
|
| 1412 |
+
batch_size = input_ids.shape[0]
|
| 1413 |
+
else:
|
| 1414 |
+
batch_size = inputs_embeds.shape[0]
|
| 1415 |
+
|
| 1416 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
| 1417 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 1418 |
+
if self.config.pad_token_id is None:
|
| 1419 |
+
sequence_lengths = -1
|
| 1420 |
+
else:
|
| 1421 |
+
if input_ids is not None:
|
| 1422 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
| 1423 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
| 1424 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
| 1425 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
| 1426 |
+
else:
|
| 1427 |
+
sequence_lengths = -1
|
| 1428 |
+
|
| 1429 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
| 1430 |
+
|
| 1431 |
+
loss = None
|
| 1432 |
+
if labels is not None:
|
| 1433 |
+
labels = labels.to(logits.device)
|
| 1434 |
+
if self.config.problem_type is None:
|
| 1435 |
+
if self.num_labels == 1:
|
| 1436 |
+
self.config.problem_type = "regression"
|
| 1437 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1438 |
+
self.config.problem_type = "single_label_classification"
|
| 1439 |
+
else:
|
| 1440 |
+
self.config.problem_type = "multi_label_classification"
|
| 1441 |
+
|
| 1442 |
+
if self.config.problem_type == "regression":
|
| 1443 |
+
loss_fct = MSELoss()
|
| 1444 |
+
if self.num_labels == 1:
|
| 1445 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 1446 |
+
else:
|
| 1447 |
+
loss = loss_fct(pooled_logits, labels)
|
| 1448 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1449 |
+
loss_fct = CrossEntropyLoss()
|
| 1450 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 1451 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1452 |
+
loss_fct = BCEWithLogitsLoss()
|
| 1453 |
+
loss = loss_fct(pooled_logits, labels)
|
| 1454 |
+
if not return_dict:
|
| 1455 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 1456 |
+
return ((loss,) + output) if loss is not None else output
|
| 1457 |
+
|
| 1458 |
+
return SequenceClassifierOutputWithPast(
|
| 1459 |
+
loss=loss,
|
| 1460 |
+
logits=pooled_logits,
|
| 1461 |
+
past_key_values=transformer_outputs.past_key_values,
|
| 1462 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 1463 |
+
attentions=transformer_outputs.attentions,
|
| 1464 |
+
)
|
| 1465 |
+
|
| 1466 |
+
|
| 1467 |
+
@add_start_docstrings(
|
| 1468 |
+
"""
|
| 1469 |
+
The HyperCLOVAX Model transformer with a span classification head on top for extractive question-answering tasks like
|
| 1470 |
+
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1471 |
+
""",
|
| 1472 |
+
HyperCLOVAX_START_DOCSTRING,
|
| 1473 |
+
)
|
| 1474 |
+
class HyperCLOVAXForQuestionAnswering(HyperCLOVAXPreTrainedModel):
|
| 1475 |
+
base_model_prefix = "transformer"
|
| 1476 |
+
|
| 1477 |
+
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->HyperCLOVAX
|
| 1478 |
+
def __init__(self, config):
|
| 1479 |
+
super().__init__(config)
|
| 1480 |
+
self.transformer = HyperCLOVAXModel(config)
|
| 1481 |
+
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
| 1482 |
+
|
| 1483 |
+
# Initialize weights and apply final processing
|
| 1484 |
+
self.post_init()
|
| 1485 |
+
|
| 1486 |
+
def get_input_embeddings(self):
|
| 1487 |
+
return self.transformer.embed_tokens
|
| 1488 |
+
|
| 1489 |
+
def set_input_embeddings(self, value):
|
| 1490 |
+
self.transformer.embed_tokens = value
|
| 1491 |
+
|
| 1492 |
+
@add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
|
| 1493 |
+
def forward(
|
| 1494 |
+
self,
|
| 1495 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1496 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1497 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1498 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 1499 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1500 |
+
start_positions: Optional[torch.LongTensor] = None,
|
| 1501 |
+
end_positions: Optional[torch.LongTensor] = None,
|
| 1502 |
+
output_attentions: Optional[bool] = None,
|
| 1503 |
+
output_hidden_states: Optional[bool] = None,
|
| 1504 |
+
return_dict: Optional[bool] = None,
|
| 1505 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
| 1506 |
+
r"""
|
| 1507 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1508 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1509 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1510 |
+
are not taken into account for computing the loss.
|
| 1511 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1512 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1513 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1514 |
+
are not taken into account for computing the loss.
|
| 1515 |
+
"""
|
| 1516 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1517 |
+
|
| 1518 |
+
outputs = self.transformer(
|
| 1519 |
+
input_ids,
|
| 1520 |
+
attention_mask=attention_mask,
|
| 1521 |
+
position_ids=position_ids,
|
| 1522 |
+
past_key_values=past_key_values,
|
| 1523 |
+
inputs_embeds=inputs_embeds,
|
| 1524 |
+
output_attentions=output_attentions,
|
| 1525 |
+
output_hidden_states=output_hidden_states,
|
| 1526 |
+
return_dict=return_dict,
|
| 1527 |
+
)
|
| 1528 |
+
|
| 1529 |
+
sequence_output = outputs[0]
|
| 1530 |
+
|
| 1531 |
+
logits = self.qa_outputs(sequence_output)
|
| 1532 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1533 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1534 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1535 |
+
|
| 1536 |
+
total_loss = None
|
| 1537 |
+
if start_positions is not None and end_positions is not None:
|
| 1538 |
+
# If we are on multi-GPU, split add a dimension
|
| 1539 |
+
if len(start_positions.size()) > 1:
|
| 1540 |
+
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
| 1541 |
+
if len(end_positions.size()) > 1:
|
| 1542 |
+
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
| 1543 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1544 |
+
ignored_index = start_logits.size(1)
|
| 1545 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1546 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1547 |
+
|
| 1548 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1549 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1550 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1551 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1552 |
+
|
| 1553 |
+
if not return_dict:
|
| 1554 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 1555 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1556 |
+
|
| 1557 |
+
return QuestionAnsweringModelOutput(
|
| 1558 |
+
loss=total_loss,
|
| 1559 |
+
start_logits=start_logits,
|
| 1560 |
+
end_logits=end_logits,
|
| 1561 |
+
hidden_states=outputs.hidden_states,
|
| 1562 |
+
attentions=outputs.attentions,
|
| 1563 |
+
)
|
| 1564 |
+
|
| 1565 |
+
|
| 1566 |
+
@add_start_docstrings(
|
| 1567 |
+
"""
|
| 1568 |
+
The HyperCLOVAX Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
| 1569 |
+
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
| 1570 |
+
""",
|
| 1571 |
+
HyperCLOVAX_START_DOCSTRING,
|
| 1572 |
+
)
|
| 1573 |
+
class HyperCLOVAXForTokenClassification(HyperCLOVAXPreTrainedModel):
|
| 1574 |
+
def __init__(self, config):
|
| 1575 |
+
super().__init__(config)
|
| 1576 |
+
self.num_labels = config.num_labels
|
| 1577 |
+
self.model = HyperCLOVAXModel(config)
|
| 1578 |
+
if getattr(config, "classifier_dropout", None) is not None:
|
| 1579 |
+
classifier_dropout = config.classifier_dropout
|
| 1580 |
+
elif getattr(config, "hidden_dropout", None) is not None:
|
| 1581 |
+
classifier_dropout = config.hidden_dropout
|
| 1582 |
+
else:
|
| 1583 |
+
classifier_dropout = 0.1
|
| 1584 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 1585 |
+
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
| 1586 |
+
|
| 1587 |
+
# Initialize weights and apply final processing
|
| 1588 |
+
self.post_init()
|
| 1589 |
+
|
| 1590 |
+
def get_input_embeddings(self):
|
| 1591 |
+
return self.model.embed_tokens
|
| 1592 |
+
|
| 1593 |
+
def set_input_embeddings(self, value):
|
| 1594 |
+
self.model.embed_tokens = value
|
| 1595 |
+
|
| 1596 |
+
@add_start_docstrings_to_model_forward(HyperCLOVAX_INPUTS_DOCSTRING)
|
| 1597 |
+
def forward(
|
| 1598 |
+
self,
|
| 1599 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1600 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1601 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1602 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1603 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1604 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1605 |
+
use_cache: Optional[bool] = None,
|
| 1606 |
+
output_attentions: Optional[bool] = None,
|
| 1607 |
+
output_hidden_states: Optional[bool] = None,
|
| 1608 |
+
return_dict: Optional[bool] = None,
|
| 1609 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1610 |
+
r"""
|
| 1611 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1612 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1613 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1614 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1615 |
+
"""
|
| 1616 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1617 |
+
|
| 1618 |
+
outputs = self.model(
|
| 1619 |
+
input_ids,
|
| 1620 |
+
attention_mask=attention_mask,
|
| 1621 |
+
position_ids=position_ids,
|
| 1622 |
+
past_key_values=past_key_values,
|
| 1623 |
+
inputs_embeds=inputs_embeds,
|
| 1624 |
+
use_cache=use_cache,
|
| 1625 |
+
output_attentions=output_attentions,
|
| 1626 |
+
output_hidden_states=output_hidden_states,
|
| 1627 |
+
return_dict=return_dict,
|
| 1628 |
+
)
|
| 1629 |
+
sequence_output = outputs[0]
|
| 1630 |
+
sequence_output = self.dropout(sequence_output)
|
| 1631 |
+
logits = self.score(sequence_output)
|
| 1632 |
+
|
| 1633 |
+
loss = None
|
| 1634 |
+
if labels is not None:
|
| 1635 |
+
loss_fct = CrossEntropyLoss()
|
| 1636 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1637 |
+
|
| 1638 |
+
if not return_dict:
|
| 1639 |
+
output = (logits,) + outputs[2:]
|
| 1640 |
+
return ((loss,) + output) if loss is not None else output
|
| 1641 |
+
|
| 1642 |
+
return TokenClassifierOutput(
|
| 1643 |
+
loss=loss,
|
| 1644 |
+
logits=logits,
|
| 1645 |
+
hidden_states=outputs.hidden_states,
|
| 1646 |
+
attentions=outputs.attentions,
|
| 1647 |
+
)
|
| 1648 |
+
|
| 1649 |
+
|
| 1650 |
+
################################################################################################
|
| 1651 |
+
################################################################################################
|
| 1652 |
+
"""
|
| 1653 |
+
liger kernel monkey patching
|
| 1654 |
+
https://github.com/linkedin/Liger-Kernel/blob/v0.5.2/src/liger_kernel/transformers/monkey_patch.py
|
| 1655 |
+
"""
|
| 1656 |
+
|
| 1657 |
+
import inspect
|
| 1658 |
+
import logging
|
| 1659 |
+
from functools import partial
|
| 1660 |
+
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
|
| 1661 |
+
|
| 1662 |
+
import torch
|
| 1663 |
+
import torch.nn.functional as F
|
| 1664 |
+
import transformers
|
| 1665 |
+
from packaging import version
|
| 1666 |
+
from torch.nn import CrossEntropyLoss
|
| 1667 |
+
from transformers import PreTrainedModel
|
| 1668 |
+
|
| 1669 |
+
if TYPE_CHECKING:
|
| 1670 |
+
from transformers.cache_utils import Cache
|
| 1671 |
+
|
| 1672 |
+
import sys
|
| 1673 |
+
|
| 1674 |
+
from packaging.version import parse
|
| 1675 |
+
|
| 1676 |
+
if sys.version_info < (3, 8):
|
| 1677 |
+
import importlib_metadata
|
| 1678 |
+
else:
|
| 1679 |
+
import importlib.metadata as importlib_metadata
|
| 1680 |
+
|
| 1681 |
+
try:
|
| 1682 |
+
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
| 1683 |
+
from liger_kernel.transformers.functional import liger_cross_entropy
|
| 1684 |
+
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
| 1685 |
+
LigerFusedLinearCrossEntropyLoss,
|
| 1686 |
+
)
|
| 1687 |
+
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
| 1688 |
+
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
| 1689 |
+
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
| 1690 |
+
|
| 1691 |
+
_is_liger_kernel_available = True
|
| 1692 |
+
|
| 1693 |
+
LIGER_KERNEL_MATCHING_VERSION = parse("0.5.2")
|
| 1694 |
+
liger_kernel_version = parse(importlib_metadata.version("liger_kernel"))
|
| 1695 |
+
_is_liger_kernel_version_matching = (
|
| 1696 |
+
liger_kernel_version.major,
|
| 1697 |
+
liger_kernel_version.minor,
|
| 1698 |
+
liger_kernel_version.release[-1],
|
| 1699 |
+
) == (
|
| 1700 |
+
LIGER_KERNEL_MATCHING_VERSION.major,
|
| 1701 |
+
LIGER_KERNEL_MATCHING_VERSION.minor,
|
| 1702 |
+
LIGER_KERNEL_MATCHING_VERSION.release[-1],
|
| 1703 |
+
)
|
| 1704 |
+
except Exception:
|
| 1705 |
+
_is_liger_kernel_available = False
|
| 1706 |
+
_is_liger_kernel_version_matching = False
|
| 1707 |
+
|
| 1708 |
+
|
| 1709 |
+
def lce_forward_deprecated(
|
| 1710 |
+
self,
|
| 1711 |
+
input_ids: torch.LongTensor = None,
|
| 1712 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1713 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1714 |
+
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
|
| 1715 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1716 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1717 |
+
use_cache: Optional[bool] = None,
|
| 1718 |
+
output_attentions: Optional[bool] = None,
|
| 1719 |
+
output_hidden_states: Optional[bool] = None,
|
| 1720 |
+
return_dict: Optional[bool] = None,
|
| 1721 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1722 |
+
num_logits_to_keep: int = 0,
|
| 1723 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1724 |
+
|
| 1725 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1726 |
+
output_hidden_states = (
|
| 1727 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1728 |
+
)
|
| 1729 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1730 |
+
|
| 1731 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1732 |
+
outputs = self.model(
|
| 1733 |
+
input_ids=input_ids,
|
| 1734 |
+
attention_mask=attention_mask,
|
| 1735 |
+
position_ids=position_ids,
|
| 1736 |
+
past_key_values=past_key_values,
|
| 1737 |
+
inputs_embeds=inputs_embeds,
|
| 1738 |
+
use_cache=use_cache,
|
| 1739 |
+
output_attentions=output_attentions,
|
| 1740 |
+
output_hidden_states=output_hidden_states,
|
| 1741 |
+
return_dict=return_dict,
|
| 1742 |
+
cache_position=cache_position,
|
| 1743 |
+
)
|
| 1744 |
+
hidden_states = outputs[0]
|
| 1745 |
+
|
| 1746 |
+
loss = None
|
| 1747 |
+
logits = None
|
| 1748 |
+
|
| 1749 |
+
if self.training and (labels is not None):
|
| 1750 |
+
if num_logits_to_keep != 0:
|
| 1751 |
+
hidden_states = hidden_states[:, -num_logits_to_keep:, :] # not sure if it has bug
|
| 1752 |
+
hidden_states = hidden_states * self.config.logits_scaling ## muP
|
| 1753 |
+
|
| 1754 |
+
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
|
| 1755 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1756 |
+
|
| 1757 |
+
# flatten tokens
|
| 1758 |
+
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
|
| 1759 |
+
shift_labels = shift_labels.view(-1)
|
| 1760 |
+
|
| 1761 |
+
lce = LigerFusedLinearCrossEntropyLoss()
|
| 1762 |
+
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
|
| 1763 |
+
|
| 1764 |
+
else:
|
| 1765 |
+
assert self.config.pretraining_tp == 1, "not supported"
|
| 1766 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
|
| 1767 |
+
logits = logits * self.config.logits_scaling ## muP
|
| 1768 |
+
|
| 1769 |
+
if labels is not None:
|
| 1770 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 1771 |
+
logits = logits.float()
|
| 1772 |
+
# Shift so that tokens < n predict n
|
| 1773 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1774 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1775 |
+
# Flatten the tokens
|
| 1776 |
+
loss_fct = CrossEntropyLoss()
|
| 1777 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1778 |
+
shift_labels = shift_labels.view(-1)
|
| 1779 |
+
# Enable model parallelism
|
| 1780 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1781 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1782 |
+
|
| 1783 |
+
if not return_dict:
|
| 1784 |
+
output = (logits,) + outputs[1:]
|
| 1785 |
+
return (loss,) + output if loss is not None else output
|
| 1786 |
+
|
| 1787 |
+
return CausalLMOutputWithPast(
|
| 1788 |
+
loss=loss,
|
| 1789 |
+
logits=logits,
|
| 1790 |
+
past_key_values=outputs.past_key_values,
|
| 1791 |
+
hidden_states=outputs.hidden_states,
|
| 1792 |
+
attentions=outputs.attentions,
|
| 1793 |
+
)
|
| 1794 |
+
|
| 1795 |
+
|
| 1796 |
+
def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
| 1797 |
+
# Binds a new method to a module instance so that self is passed as the first argument
|
| 1798 |
+
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
|
| 1799 |
+
|
| 1800 |
+
|
| 1801 |
+
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
|
| 1802 |
+
module.offset = offset
|
| 1803 |
+
module.casting_mode = casting_mode
|
| 1804 |
+
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
|
| 1805 |
+
module.in_place = in_place
|
| 1806 |
+
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
|
| 1807 |
+
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
|
| 1808 |
+
|
| 1809 |
+
|
| 1810 |
+
def apply_liger_kernel_to_hyperclovax(
|
| 1811 |
+
rope: bool = True,
|
| 1812 |
+
cross_entropy: bool = False,
|
| 1813 |
+
fused_linear_cross_entropy: bool = True,
|
| 1814 |
+
rms_norm: bool = True,
|
| 1815 |
+
swiglu: bool = True,
|
| 1816 |
+
model: PreTrainedModel = None,
|
| 1817 |
+
) -> None:
|
| 1818 |
+
|
| 1819 |
+
assert not cross_entropy, "not supported"
|
| 1820 |
+
if rope:
|
| 1821 |
+
apply_rotary_pos_emb = liger_rotary_pos_emb
|
| 1822 |
+
if rms_norm:
|
| 1823 |
+
HyperCLOVAXRMSNorm = LigerRMSNorm
|
| 1824 |
+
if swiglu:
|
| 1825 |
+
HyperCLOVAXMLP = LigerSwiGLUMLP
|
| 1826 |
+
# to use VLM forward in VLM repo
|
| 1827 |
+
# if fused_linear_cross_entropy:
|
| 1828 |
+
# HyperCLOVAXForCausalLM.forward = lce_forward_deprecated
|
| 1829 |
+
|
| 1830 |
+
if model is not None:
|
| 1831 |
+
# The model instance already exists, so we need to additionally patch the
|
| 1832 |
+
# instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
|
| 1833 |
+
|
| 1834 |
+
# get the base model from the model instance
|
| 1835 |
+
base_model: HyperCLOVAXModel = getattr(model, model.base_model_prefix, model)
|
| 1836 |
+
|
| 1837 |
+
if rms_norm:
|
| 1838 |
+
_patch_rms_norm_module(base_model.norm)
|
| 1839 |
+
|
| 1840 |
+
for decoder_layer in base_model.layers:
|
| 1841 |
+
if swiglu:
|
| 1842 |
+
_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
|
| 1843 |
+
if rms_norm:
|
| 1844 |
+
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
| 1845 |
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
| 1846 |
+
if decoder_layer.use_post_norm:
|
| 1847 |
+
_patch_rms_norm_module(decoder_layer.post_norm1)
|
| 1848 |
+
_patch_rms_norm_module(decoder_layer.post_norm2)
|
| 1849 |
+
|
| 1850 |
+
|
| 1851 |
+
def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
|
| 1852 |
+
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
|
| 1853 |
+
assert model_type == "hyperclovax"
|
| 1854 |
+
apply_fn = apply_liger_kernel_to_hyperclovax
|
| 1855 |
+
apply_fn_signature = inspect.signature(apply_fn)
|
| 1856 |
+
|
| 1857 |
+
# Filter out the keyword arguments that are not supported by the apply function
|
| 1858 |
+
applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
|
| 1859 |
+
logger.info(
|
| 1860 |
+
f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
|
| 1861 |
+
)
|
| 1862 |
+
apply_fn(model=model, **applicable_kwargs)
|
| 1863 |
+
|
| 1864 |
+
|
| 1865 |
+
################################################################################################
|
| 1866 |
+
################################################################################################
|
modeling_vlm.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
patch_vuvlm.py
ADDED
|
@@ -0,0 +1,1085 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import gc
|
| 3 |
+
import inspect
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from functools import partial
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from liger_kernel.transformers import (
|
| 15 |
+
LigerCrossEntropyLoss,
|
| 16 |
+
LigerFusedLinearCrossEntropyLoss,
|
| 17 |
+
)
|
| 18 |
+
from torch.nn import CrossEntropyLoss
|
| 19 |
+
from transformers import AutoTokenizer
|
| 20 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 21 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 22 |
+
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
|
| 23 |
+
|
| 24 |
+
from hcxvlm.models.ulysses.sp_utils import (
|
| 25 |
+
gather_outputs_and_unpad,
|
| 26 |
+
get_ulysses_sequence_parallel_group,
|
| 27 |
+
get_ulysses_sequence_parallel_rank,
|
| 28 |
+
get_ulysses_sequence_parallel_world_size,
|
| 29 |
+
slice_input_tensor,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from .configuration_vlm import HCXVisionConfig
|
| 33 |
+
from .modeling_vlm import HCXVisionForCausalLM, get_rank
|
| 34 |
+
|
| 35 |
+
extra_special_tokens = {
|
| 36 |
+
"image_token": "<|IMAGE_PAD|>",
|
| 37 |
+
"discrete_image_token": "<|DISCRETE_IMAGE_PAD|>",
|
| 38 |
+
"discrete_image_unit_0_id": "<|vision00000|>",
|
| 39 |
+
"video_token": "<|VIDEO_PAD|>",
|
| 40 |
+
"video_audio_token": "<|VIDEO_AUDIO_PAD|>",
|
| 41 |
+
"audio_token": "<|AUDIO_PAD|>",
|
| 42 |
+
"discrete_audio_token": "<|DISCRETE_AUDIO_PAD|>",
|
| 43 |
+
"discrete_audio_unit_0_id": "<|audio0000|>",
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""):
|
| 48 |
+
old_keys = []
|
| 49 |
+
new_keys = []
|
| 50 |
+
for key in state_dict.keys():
|
| 51 |
+
new_key = None
|
| 52 |
+
if "gamma" in key:
|
| 53 |
+
new_key = key.replace("gamma", "weight")
|
| 54 |
+
if "beta" in key:
|
| 55 |
+
new_key = key.replace("beta", "bias")
|
| 56 |
+
if new_key:
|
| 57 |
+
old_keys.append(key)
|
| 58 |
+
new_keys.append(new_key)
|
| 59 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
| 60 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 61 |
+
|
| 62 |
+
metadata = getattr(state_dict, "_metadata", None)
|
| 63 |
+
state_dict = state_dict.copy()
|
| 64 |
+
if metadata is not None:
|
| 65 |
+
state_dict._metadata = metadata
|
| 66 |
+
|
| 67 |
+
error_msgs = []
|
| 68 |
+
|
| 69 |
+
def load(module: nn.Module, state_dict, prefix=""):
|
| 70 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
| 71 |
+
args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs)
|
| 72 |
+
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
| 73 |
+
if is_deepspeed_zero3_enabled():
|
| 74 |
+
import deepspeed
|
| 75 |
+
|
| 76 |
+
named_parameters = dict(
|
| 77 |
+
module.named_parameters(prefix=prefix[:-1], recurse=False)
|
| 78 |
+
)
|
| 79 |
+
params_to_gather = [
|
| 80 |
+
named_parameters[k]
|
| 81 |
+
for k in state_dict.keys()
|
| 82 |
+
if k in named_parameters
|
| 83 |
+
]
|
| 84 |
+
if len(params_to_gather) > 0:
|
| 85 |
+
with deepspeed.zero.GatheredParameters(
|
| 86 |
+
params_to_gather, modifier_rank=0
|
| 87 |
+
):
|
| 88 |
+
if torch.distributed.get_rank() == 0:
|
| 89 |
+
module._load_from_state_dict(*args)
|
| 90 |
+
else:
|
| 91 |
+
module._load_from_state_dict(*args)
|
| 92 |
+
|
| 93 |
+
for name, child in module._modules.items():
|
| 94 |
+
if child is not None:
|
| 95 |
+
load(child, state_dict, prefix + name + ".")
|
| 96 |
+
|
| 97 |
+
load(model_to_load, state_dict, prefix=start_prefix)
|
| 98 |
+
del state_dict
|
| 99 |
+
|
| 100 |
+
return error_msgs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def load_sharded_checkpoint(
|
| 104 |
+
model,
|
| 105 |
+
folder,
|
| 106 |
+
pick_prefix="",
|
| 107 |
+
replace_prefix_list=[],
|
| 108 |
+
replace_prefix_dict={},
|
| 109 |
+
print_info=True,
|
| 110 |
+
):
|
| 111 |
+
if folder is None:
|
| 112 |
+
return {}
|
| 113 |
+
|
| 114 |
+
files = os.listdir(folder)
|
| 115 |
+
|
| 116 |
+
pytorch_bin_files = [
|
| 117 |
+
file
|
| 118 |
+
for file in files
|
| 119 |
+
if file.startswith("pytorch_model") and file.endswith(".bin")
|
| 120 |
+
]
|
| 121 |
+
safetensor_files = [file for file in files if file.endswith(".safetensors")]
|
| 122 |
+
shard_index_file = [file for file in files if file.endswith(".index.json")]
|
| 123 |
+
|
| 124 |
+
index_present = len(shard_index_file) > 0
|
| 125 |
+
index_file = os.path.join(folder, shard_index_file[0]) if index_present else []
|
| 126 |
+
|
| 127 |
+
is_safetensor = len(safetensor_files) > 0
|
| 128 |
+
|
| 129 |
+
model_keys = model.state_dict().keys()
|
| 130 |
+
|
| 131 |
+
if is_safetensor:
|
| 132 |
+
from safetensors.torch import load_file
|
| 133 |
+
|
| 134 |
+
load_function = load_file
|
| 135 |
+
shard_files = safetensor_files
|
| 136 |
+
else:
|
| 137 |
+
load_function = partial(torch.load, map_location="cpu")
|
| 138 |
+
shard_files = pytorch_bin_files
|
| 139 |
+
|
| 140 |
+
if index_present:
|
| 141 |
+
with open(index_file, "r", encoding="utf-8") as f:
|
| 142 |
+
index = json.load(f)
|
| 143 |
+
loaded_keys = index["weight_map"].keys()
|
| 144 |
+
if pick_prefix:
|
| 145 |
+
loaded_keys = [
|
| 146 |
+
k[len(pick_prefix) :] for k in loaded_keys if k.startswith(pick_prefix)
|
| 147 |
+
]
|
| 148 |
+
if replace_prefix_list:
|
| 149 |
+
for rep_prefix in replace_prefix_list:
|
| 150 |
+
loaded_keys = [
|
| 151 |
+
k[len(rep_prefix) :] if k.startswith(rep_prefix) else k
|
| 152 |
+
for k in loaded_keys
|
| 153 |
+
]
|
| 154 |
+
if replace_prefix_dict:
|
| 155 |
+
for rep_prefix in replace_prefix_dict:
|
| 156 |
+
loaded_keys = [
|
| 157 |
+
(
|
| 158 |
+
k.replace(rep_prefix, replace_prefix_dict[rep_prefix])
|
| 159 |
+
if k.startswith(rep_prefix)
|
| 160 |
+
else k
|
| 161 |
+
)
|
| 162 |
+
for k in loaded_keys
|
| 163 |
+
]
|
| 164 |
+
|
| 165 |
+
for i, shard_file in enumerate(shard_files):
|
| 166 |
+
state_dict = load_function(os.path.join(folder, shard_file))
|
| 167 |
+
|
| 168 |
+
if pick_prefix:
|
| 169 |
+
state_dict = {
|
| 170 |
+
k[len(pick_prefix) :]: v
|
| 171 |
+
for k, v in state_dict.items()
|
| 172 |
+
if k.startswith(pick_prefix)
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
for rep_prefix in replace_prefix_list:
|
| 176 |
+
state_dict = {
|
| 177 |
+
k[len(rep_prefix) :] if k.startswith(rep_prefix) else k: v
|
| 178 |
+
for k, v in state_dict.items()
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
for rep_prefix in replace_prefix_dict:
|
| 182 |
+
state_dict = {
|
| 183 |
+
(
|
| 184 |
+
k.replace(rep_prefix, replace_prefix_dict[rep_prefix])
|
| 185 |
+
if k.startswith(rep_prefix)
|
| 186 |
+
else k
|
| 187 |
+
): v
|
| 188 |
+
for k, v in state_dict.items()
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
if is_deepspeed_zero3_enabled():
|
| 192 |
+
rank = torch.distributed.get_rank()
|
| 193 |
+
print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}")
|
| 194 |
+
load_state_dict_into_model(model, state_dict, strict=False)
|
| 195 |
+
elif is_fsdp_enabled():
|
| 196 |
+
if is_local_dist_rank_0():
|
| 197 |
+
model.load_state_dict(state_dict, strict=False)
|
| 198 |
+
else:
|
| 199 |
+
model.load_state_dict(state_dict, strict=False)
|
| 200 |
+
|
| 201 |
+
if not index_present:
|
| 202 |
+
loaded_keys = state_dict.keys()
|
| 203 |
+
|
| 204 |
+
del state_dict
|
| 205 |
+
gc.collect()
|
| 206 |
+
|
| 207 |
+
missing_keys = [key for key in model_keys if key not in loaded_keys]
|
| 208 |
+
unexpected_keys = [key for key in loaded_keys if key not in model_keys]
|
| 209 |
+
|
| 210 |
+
if get_rank() == 0 and print_info:
|
| 211 |
+
print(f"[info] missing_keys: {missing_keys}")
|
| 212 |
+
print(f"[info] unexpected_keys: {unexpected_keys}")
|
| 213 |
+
|
| 214 |
+
return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class HCXVisionForCausalLM_VU(HCXVisionForCausalLM):
|
| 218 |
+
def __init__(self, config, **kwargs):
|
| 219 |
+
self.use_liger = kwargs.pop("use_liger", True)
|
| 220 |
+
self.use_fused_ce = kwargs.pop("use_fused_ce", True)
|
| 221 |
+
self.use_meansum_loss = kwargs.pop("use_meansum_loss", True)
|
| 222 |
+
self.use_turnmeansum_loss = kwargs.pop("use_turnmeansum_loss", False)
|
| 223 |
+
self.use_sqrtsum_loss = kwargs.pop("use_sqrtsum_loss", False)
|
| 224 |
+
use_sum_loss = True if kwargs.pop("use_sum_loss", False) else False
|
| 225 |
+
|
| 226 |
+
self.sequence_parallel_size = kwargs.pop("sequence_parallel_size", 1)
|
| 227 |
+
self.sp_manager = kwargs.pop("sp_manager", None)
|
| 228 |
+
self.train_video = kwargs.pop("train_video", False)
|
| 229 |
+
|
| 230 |
+
assert (
|
| 231 |
+
int(self.use_meansum_loss)
|
| 232 |
+
+ int(self.use_turnmeansum_loss)
|
| 233 |
+
+ int(self.use_sqrtsum_loss)
|
| 234 |
+
) <= 1, "use_meansum_loss, use_turnmeansum_loss, use_sqrtsum_loss 중 둘 이상을 동시에 True로 설정할 수 없습니다."
|
| 235 |
+
|
| 236 |
+
if self.use_meansum_loss or self.use_turnmeansum_loss or self.use_sqrtsum_loss:
|
| 237 |
+
self.reduction = "none"
|
| 238 |
+
elif use_sum_loss:
|
| 239 |
+
self.reduction = "sum"
|
| 240 |
+
else:
|
| 241 |
+
self.reduction = "mean"
|
| 242 |
+
|
| 243 |
+
super().__init__(config, **kwargs)
|
| 244 |
+
if config.text_config.model_type == "hyperclovax" and self.use_liger:
|
| 245 |
+
self.language_model._get_apply_liger_kernel_converter()(
|
| 246 |
+
model=self.language_model
|
| 247 |
+
)
|
| 248 |
+
print("[info] use liger kernel for hcx 24b")
|
| 249 |
+
if config.freeze_encoder:
|
| 250 |
+
for param in self.vision_model.parameters():
|
| 251 |
+
param.requires_grad = False
|
| 252 |
+
assert (
|
| 253 |
+
all(param.requires_grad for param in self.vision_model.parameters())
|
| 254 |
+
== False
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
@classmethod
|
| 258 |
+
def from_pretrained(
|
| 259 |
+
cls,
|
| 260 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 261 |
+
text_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 262 |
+
vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 263 |
+
discrete_vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 264 |
+
audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 265 |
+
discrete_audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 266 |
+
q_former_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
| 267 |
+
without_llm: bool = False,
|
| 268 |
+
*model_args,
|
| 269 |
+
**kwargs,
|
| 270 |
+
):
|
| 271 |
+
"""
|
| 272 |
+
:param pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for LLM(text_model_name_or_path) e.g. /path/to/model/
|
| 273 |
+
:param vision_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VisionModule(HyperClova-VisionModule) e.g. /path/to/vision/module/
|
| 274 |
+
:param q_former_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VLM e.g. /path/to/vlm/checkpoint/
|
| 275 |
+
:param without_llm: Bool: False: init/load llm weight from pre-trained True: init/load llm weight from dummy file
|
| 276 |
+
:param model_args:
|
| 277 |
+
:param kwargs:
|
| 278 |
+
:return:
|
| 279 |
+
"""
|
| 280 |
+
assert pretrained_model_name_or_path is not None or (
|
| 281 |
+
text_model_name_or_path is not None
|
| 282 |
+
and vision_model_name_or_path is not None
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
cache_dirpath = kwargs.pop("cache_dirpath", None)
|
| 286 |
+
if cache_dirpath is None:
|
| 287 |
+
cache_dirpath = "~/.cache"
|
| 288 |
+
|
| 289 |
+
runtime_only_keys = {
|
| 290 |
+
"use_liger",
|
| 291 |
+
"use_fused_ce",
|
| 292 |
+
"use_meansum_loss",
|
| 293 |
+
"use_turnmeansum_loss",
|
| 294 |
+
"use_sqrtsum_loss",
|
| 295 |
+
"use_sum_loss",
|
| 296 |
+
"sequence_parallel_size",
|
| 297 |
+
"sp_manager",
|
| 298 |
+
"train_video",
|
| 299 |
+
}
|
| 300 |
+
runtime_kwargs = {}
|
| 301 |
+
for k in list(runtime_only_keys):
|
| 302 |
+
if k in kwargs:
|
| 303 |
+
runtime_kwargs[k] = kwargs.pop(k)
|
| 304 |
+
|
| 305 |
+
kwargs["vision_model_name_or_path"] = vision_model_name_or_path
|
| 306 |
+
kwargs["discrete_vision_model_name_or_path"] = (
|
| 307 |
+
discrete_vision_model_name_or_path
|
| 308 |
+
)
|
| 309 |
+
kwargs["audio_model_name_or_path"] = audio_model_name_or_path
|
| 310 |
+
kwargs["discrete_audio_model_name_or_path"] = discrete_audio_model_name_or_path
|
| 311 |
+
|
| 312 |
+
save_only_vision = (
|
| 313 |
+
kwargs.pop("save_only_vision") if "save_only_vision" in kwargs else False
|
| 314 |
+
)
|
| 315 |
+
save_only_qformer = (
|
| 316 |
+
kwargs.pop("save_only_qformer") if "save_only_qformer" in kwargs else False
|
| 317 |
+
)
|
| 318 |
+
save_shard_size = (
|
| 319 |
+
kwargs.pop("save_shard_size") if "save_shard_size" in kwargs else "5GB"
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
def _purge_runtime_from_config(cfg):
|
| 323 |
+
for rk in runtime_only_keys:
|
| 324 |
+
if hasattr(cfg, rk):
|
| 325 |
+
delattr(cfg, rk)
|
| 326 |
+
|
| 327 |
+
template_path = "hcxvlm/dataset/chat_template.jinja"
|
| 328 |
+
with open(template_path, "r", encoding="utf-8") as f:
|
| 329 |
+
chat_template_str = f.read()
|
| 330 |
+
if without_llm:
|
| 331 |
+
assert pretrained_model_name_or_path is not None and os.path.exists(
|
| 332 |
+
pretrained_model_name_or_path
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
dummy_config = HCXVisionConfig.from_pretrained(
|
| 336 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 337 |
+
*model_args,
|
| 338 |
+
**kwargs,
|
| 339 |
+
)
|
| 340 |
+
_purge_runtime_from_config(dummy_config)
|
| 341 |
+
dummy_config.text_config.num_hidden_layers = 0
|
| 342 |
+
dummy_config.text_config.num_attention_heads = 1
|
| 343 |
+
|
| 344 |
+
if isinstance(
|
| 345 |
+
dummy_config.vision_model_name_or_path, str
|
| 346 |
+
) and os.path.exists(dummy_config.vision_model_name_or_path):
|
| 347 |
+
vision_model_name_or_path = dummy_config.vision_model_name_or_path
|
| 348 |
+
assert isinstance(vision_model_name_or_path, str) and os.path.exists(
|
| 349 |
+
vision_model_name_or_path
|
| 350 |
+
), f"# [error] invalid vision_model_name_or_path: {vision_model_name_or_path}"
|
| 351 |
+
dummy_config.vision_model_name_or_path = vision_model_name_or_path
|
| 352 |
+
dummy_config.vision_config._name_or_path = vision_model_name_or_path
|
| 353 |
+
dummy_config.vision_config.vison_pretrained_name_or_path = (
|
| 354 |
+
vision_model_name_or_path
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
model = super().from_pretrained(
|
| 358 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 359 |
+
without_llm=True,
|
| 360 |
+
config=dummy_config,
|
| 361 |
+
*model_args,
|
| 362 |
+
**{**kwargs, **runtime_kwargs},
|
| 363 |
+
)
|
| 364 |
+
model.tokenizer = AutoTokenizer.from_pretrained(
|
| 365 |
+
pretrained_model_name_or_path
|
| 366 |
+
)
|
| 367 |
+
model.tokenizer.chat_template = chat_template_str
|
| 368 |
+
model.transformer = None
|
| 369 |
+
else:
|
| 370 |
+
if pretrained_model_name_or_path is not None and (
|
| 371 |
+
audio_model_name_or_path is not None
|
| 372 |
+
or discrete_audio_model_name_or_path is not None
|
| 373 |
+
or discrete_vision_model_name_or_path is not None
|
| 374 |
+
):
|
| 375 |
+
assert (
|
| 376 |
+
audio_model_name_or_path is not None
|
| 377 |
+
and discrete_audio_model_name_or_path is not None
|
| 378 |
+
and discrete_vision_model_name_or_path is not None
|
| 379 |
+
)
|
| 380 |
+
print(f"[DEBUG] image stage2 끝난 시점에서 audio 를 stage3 로 붙일때.")
|
| 381 |
+
pt_config = HCXVisionConfig.from_pretrained(
|
| 382 |
+
pretrained_model_name_or_path
|
| 383 |
+
)
|
| 384 |
+
_purge_runtime_from_config(pt_config)
|
| 385 |
+
config_dict = pt_config.to_dict()
|
| 386 |
+
config_dict["audio_model_name_or_path"] = audio_model_name_or_path
|
| 387 |
+
config_dict["discrete_audio_model_name_or_path"] = (
|
| 388 |
+
discrete_audio_model_name_or_path
|
| 389 |
+
)
|
| 390 |
+
config_dict["discrete_vision_model_name_or_path"] = (
|
| 391 |
+
discrete_vision_model_name_or_path
|
| 392 |
+
)
|
| 393 |
+
config = HCXVisionConfig.from_dict(config_dict)
|
| 394 |
+
print(f"config: {config}")
|
| 395 |
+
model = super().from_pretrained(
|
| 396 |
+
pretrained_model_name_or_path,
|
| 397 |
+
without_llm=False,
|
| 398 |
+
config=config,
|
| 399 |
+
_fast_init=False,
|
| 400 |
+
*model_args,
|
| 401 |
+
**kwargs,
|
| 402 |
+
)
|
| 403 |
+
model.tokenizer = AutoTokenizer.from_pretrained(
|
| 404 |
+
pretrained_model_name_or_path
|
| 405 |
+
)
|
| 406 |
+
model.tokenizer.chat_template = chat_template_str
|
| 407 |
+
elif isinstance(q_former_model_name_or_path, str):
|
| 408 |
+
config = HCXVisionConfig.from_dict(
|
| 409 |
+
{"text_model_name_or_path": text_model_name_or_path, **kwargs}
|
| 410 |
+
)
|
| 411 |
+
_purge_runtime_from_config(config)
|
| 412 |
+
model = super().from_pretrained(
|
| 413 |
+
q_former_model_name_or_path,
|
| 414 |
+
without_llm=False,
|
| 415 |
+
config=config,
|
| 416 |
+
_fast_init=False,
|
| 417 |
+
*model_args,
|
| 418 |
+
**{**kwargs, **runtime_kwargs},
|
| 419 |
+
)
|
| 420 |
+
model.tokenizer = AutoTokenizer.from_pretrained(
|
| 421 |
+
q_former_model_name_or_path
|
| 422 |
+
)
|
| 423 |
+
model.tokenizer.chat_template = chat_template_str
|
| 424 |
+
elif pretrained_model_name_or_path is not None:
|
| 425 |
+
config = HCXVisionConfig.from_pretrained(
|
| 426 |
+
pretrained_model_name_or_path, *model_args, **kwargs
|
| 427 |
+
)
|
| 428 |
+
_purge_runtime_from_config(config)
|
| 429 |
+
model = super().from_pretrained(
|
| 430 |
+
pretrained_model_name_or_path,
|
| 431 |
+
*model_args,
|
| 432 |
+
config=config,
|
| 433 |
+
**runtime_kwargs,
|
| 434 |
+
)
|
| 435 |
+
model.tokenizer = AutoTokenizer.from_pretrained(
|
| 436 |
+
pretrained_model_name_or_path
|
| 437 |
+
)
|
| 438 |
+
model.tokenizer.chat_template = chat_template_str
|
| 439 |
+
else:
|
| 440 |
+
config = HCXVisionConfig.from_dict(
|
| 441 |
+
{"text_model_name_or_path": text_model_name_or_path, **kwargs}
|
| 442 |
+
)
|
| 443 |
+
_purge_runtime_from_config(config)
|
| 444 |
+
model = HCXVisionForCausalLM_VU(
|
| 445 |
+
config, *model_args, **{**kwargs, **runtime_kwargs}
|
| 446 |
+
)
|
| 447 |
+
model.tokenizer = AutoTokenizer.from_pretrained(text_model_name_or_path)
|
| 448 |
+
model.tokenizer.chat_template = chat_template_str
|
| 449 |
+
model.mm_projector.apply(model._init_weights)
|
| 450 |
+
|
| 451 |
+
img_start_id = model.tokenizer.encode(
|
| 452 |
+
extra_special_tokens["image_token"], add_special_tokens=False
|
| 453 |
+
)
|
| 454 |
+
assert (
|
| 455 |
+
len(img_start_id) == 1
|
| 456 |
+
), f'{extra_special_tokens["image_token"]} was not encoded into a single special token. Encoding result: {img_start_id}'
|
| 457 |
+
model.config.img_start_id = img_start_id[0]
|
| 458 |
+
model.config.image_token_id = img_start_id[0]
|
| 459 |
+
|
| 460 |
+
video_start_id = model.tokenizer.encode(
|
| 461 |
+
extra_special_tokens["video_token"], add_special_tokens=False
|
| 462 |
+
)
|
| 463 |
+
assert (
|
| 464 |
+
len(video_start_id) == 1
|
| 465 |
+
), f"video_token was not encoded into a single special token. Encoding result: {video_start_id}"
|
| 466 |
+
model.config.video_start_id = video_start_id[0]
|
| 467 |
+
model.config.video_token_id = video_start_id[0]
|
| 468 |
+
|
| 469 |
+
video_audio_start_id = model.tokenizer.encode(
|
| 470 |
+
extra_special_tokens["video_audio_token"], add_special_tokens=False
|
| 471 |
+
)
|
| 472 |
+
assert (
|
| 473 |
+
len(video_audio_start_id) == 1
|
| 474 |
+
), f"video_audio_token was not encoded into a single special token. Encoding result: {video_audio_start_id}"
|
| 475 |
+
model.config.video_audio_start_id = video_audio_start_id[0]
|
| 476 |
+
model.config.video_audio_token_id = video_audio_start_id[0]
|
| 477 |
+
|
| 478 |
+
if (
|
| 479 |
+
audio_model_name_or_path is not None
|
| 480 |
+
or discrete_audio_model_name_or_path is not None
|
| 481 |
+
or discrete_vision_model_name_or_path is not None
|
| 482 |
+
):
|
| 483 |
+
audio_start_id = model.tokenizer.encode(
|
| 484 |
+
extra_special_tokens["audio_token"], add_special_tokens=False
|
| 485 |
+
)
|
| 486 |
+
assert (
|
| 487 |
+
len(audio_start_id) == 1
|
| 488 |
+
), f"audio_token was not encoded into a single special token. Encoding result: {audio_start_id}"
|
| 489 |
+
model.config.audio_start_id = audio_start_id[0]
|
| 490 |
+
model.config.audio_token_id = audio_start_id[0]
|
| 491 |
+
|
| 492 |
+
discrete_audio_start_id = model.tokenizer.encode(
|
| 493 |
+
extra_special_tokens["discrete_audio_token"], add_special_tokens=False
|
| 494 |
+
)
|
| 495 |
+
assert (
|
| 496 |
+
len(discrete_audio_start_id) == 1
|
| 497 |
+
), f"discrete_audio_token was not encoded into a single special token. Encoding result: {discrete_audio_start_id}"
|
| 498 |
+
model.config.discrete_audio_start_id = discrete_audio_start_id[0]
|
| 499 |
+
model.config.discrete_audio_token_id = discrete_audio_start_id[0]
|
| 500 |
+
discrete_audio_unit_0_id = model.tokenizer.encode(
|
| 501 |
+
extra_special_tokens["discrete_audio_unit_0_id"],
|
| 502 |
+
add_special_tokens=False,
|
| 503 |
+
)
|
| 504 |
+
assert (
|
| 505 |
+
len(discrete_audio_unit_0_id) == 1
|
| 506 |
+
), f'{extra_special_tokens["discrete_audio_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_audio_unit_0_id}'
|
| 507 |
+
model.config.discrete_audio_unit_0_id = discrete_audio_unit_0_id[0]
|
| 508 |
+
|
| 509 |
+
discrete_image_start_id = model.tokenizer.encode(
|
| 510 |
+
extra_special_tokens["discrete_image_token"], add_special_tokens=False
|
| 511 |
+
)
|
| 512 |
+
assert (
|
| 513 |
+
len(discrete_image_start_id) == 1
|
| 514 |
+
), f'{extra_special_tokens["discrete_image_token"]} was not encoded into a single special token. Encoding result: {discrete_image_start_id}'
|
| 515 |
+
model.config.discrete_image_start_id = discrete_image_start_id[0]
|
| 516 |
+
model.config.discrete_image_token_id = discrete_image_start_id[0]
|
| 517 |
+
discrete_image_unit_0_id = model.tokenizer.encode(
|
| 518 |
+
extra_special_tokens["discrete_image_unit_0_id"],
|
| 519 |
+
add_special_tokens=False,
|
| 520 |
+
)
|
| 521 |
+
assert (
|
| 522 |
+
len(discrete_image_unit_0_id) == 1
|
| 523 |
+
), f'{extra_special_tokens["discrete_image_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_image_unit_0_id}'
|
| 524 |
+
model.config.discrete_image_unit_0_id = discrete_image_unit_0_id[0]
|
| 525 |
+
|
| 526 |
+
model.save_only_vision = save_only_vision
|
| 527 |
+
model.save_only_qformer = save_only_qformer
|
| 528 |
+
model.save_shard_size = save_shard_size
|
| 529 |
+
|
| 530 |
+
if pretrained_model_name_or_path is None or (
|
| 531 |
+
pretrained_model_name_or_path is not None
|
| 532 |
+
and audio_model_name_or_path is not None
|
| 533 |
+
):
|
| 534 |
+
vision_model_name_or_path = kwargs.get("vision_model_name_or_path", None)
|
| 535 |
+
if vision_model_name_or_path is not None:
|
| 536 |
+
load_sharded_checkpoint(model.vision_model, vision_model_name_or_path)
|
| 537 |
+
if get_rank() == 0:
|
| 538 |
+
print("[info] vision model loading complete")
|
| 539 |
+
|
| 540 |
+
discrete_vision_model_name_or_path = kwargs.get(
|
| 541 |
+
"discrete_vision_model_name_or_path", None
|
| 542 |
+
)
|
| 543 |
+
if discrete_vision_model_name_or_path is not None:
|
| 544 |
+
|
| 545 |
+
model.discrete_vision_model.load_state_dict(
|
| 546 |
+
torch.load(
|
| 547 |
+
discrete_vision_model_name_or_path,
|
| 548 |
+
map_location=model.device,
|
| 549 |
+
weights_only=False,
|
| 550 |
+
)["model"]["sd"],
|
| 551 |
+
strict=True,
|
| 552 |
+
)
|
| 553 |
+
if get_rank() == 0:
|
| 554 |
+
print("[info] discrete vision model loading complete")
|
| 555 |
+
|
| 556 |
+
audio_model_name_or_path = kwargs.get("audio_model_name_or_path", None)
|
| 557 |
+
if audio_model_name_or_path is not None:
|
| 558 |
+
load_sharded_checkpoint(model.audio_model, audio_model_name_or_path)
|
| 559 |
+
if get_rank() == 0:
|
| 560 |
+
print("[info] audio model loading complete")
|
| 561 |
+
|
| 562 |
+
discrete_audio_model_name_or_path = kwargs.get(
|
| 563 |
+
"discrete_audio_model_name_or_path", None
|
| 564 |
+
)
|
| 565 |
+
if discrete_audio_model_name_or_path is not None:
|
| 566 |
+
|
| 567 |
+
model.discrete_audio_model.load_state_dict(
|
| 568 |
+
torch.load(
|
| 569 |
+
discrete_audio_model_name_or_path,
|
| 570 |
+
map_location=model.device,
|
| 571 |
+
weights_only=False,
|
| 572 |
+
),
|
| 573 |
+
strict=True,
|
| 574 |
+
)
|
| 575 |
+
if get_rank() == 0:
|
| 576 |
+
print("[info] discrete audio model loading complete")
|
| 577 |
+
|
| 578 |
+
if text_model_name_or_path is not None:
|
| 579 |
+
load_sharded_checkpoint(model.language_model, text_model_name_or_path)
|
| 580 |
+
if get_rank() == 0:
|
| 581 |
+
print("[info] text model loading complete")
|
| 582 |
+
|
| 583 |
+
if isinstance(q_former_model_name_or_path, str):
|
| 584 |
+
assert Path(
|
| 585 |
+
q_former_model_name_or_path
|
| 586 |
+
).exists(), f"# [error] given q_former_name_or_path not exist: {q_former_model_name_or_path}"
|
| 587 |
+
|
| 588 |
+
load_result = load_sharded_checkpoint(
|
| 589 |
+
model,
|
| 590 |
+
q_former_model_name_or_path,
|
| 591 |
+
replace_prefix_dict={
|
| 592 |
+
"vision_model.image_encoder.model.vision_tower": "vision_model",
|
| 593 |
+
"model": "language_model.model",
|
| 594 |
+
"lm_head.weight": "language_model.lm_head.weight",
|
| 595 |
+
},
|
| 596 |
+
print_info=False,
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
if get_rank() == 0:
|
| 600 |
+
missing_keys_summary = dict()
|
| 601 |
+
for key in load_result["missing_keys"]:
|
| 602 |
+
if key.split(".")[0] in missing_keys_summary:
|
| 603 |
+
missing_keys_summary[key.split(".")[0]] += 1
|
| 604 |
+
else:
|
| 605 |
+
missing_keys_summary[key.split(".")[0]] = 1
|
| 606 |
+
print(f"[info] missing_keys summary : {missing_keys_summary}")
|
| 607 |
+
print("[info] q_former model loading complete")
|
| 608 |
+
|
| 609 |
+
config: HCXVisionConfig = model.config
|
| 610 |
+
if config.model_type != "vlm":
|
| 611 |
+
model.config.model_type = "vlm"
|
| 612 |
+
|
| 613 |
+
return model
|
| 614 |
+
|
| 615 |
+
def _pad_sequence_for_sp(
|
| 616 |
+
self,
|
| 617 |
+
inputs_embeds: torch.Tensor,
|
| 618 |
+
labels: Optional[torch.Tensor],
|
| 619 |
+
sp_world_size: int,
|
| 620 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 621 |
+
"""
|
| 622 |
+
Ensure sequence length is divisible by the SP group size by padding on the sequence dimension.
|
| 623 |
+
Returns the possibly padded (inputs_embeds, labels).
|
| 624 |
+
"""
|
| 625 |
+
batch_size, seqlen, hidden_size = inputs_embeds.shape
|
| 626 |
+
remainder = seqlen % sp_world_size
|
| 627 |
+
if remainder != 0:
|
| 628 |
+
print(
|
| 629 |
+
f"[info] Padding sequence dimension to make it divisible by {sp_world_size}"
|
| 630 |
+
)
|
| 631 |
+
pad_len = sp_world_size - remainder
|
| 632 |
+
pad_embeds = torch.zeros(
|
| 633 |
+
(batch_size, pad_len, hidden_size),
|
| 634 |
+
dtype=inputs_embeds.dtype,
|
| 635 |
+
device=inputs_embeds.device,
|
| 636 |
+
)
|
| 637 |
+
inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1)
|
| 638 |
+
|
| 639 |
+
if labels is not None:
|
| 640 |
+
ignore_index = getattr(self.config, "ignore_index", -100)
|
| 641 |
+
pad_labels = torch.full(
|
| 642 |
+
(batch_size, pad_len),
|
| 643 |
+
fill_value=ignore_index,
|
| 644 |
+
dtype=labels.dtype,
|
| 645 |
+
device=labels.device,
|
| 646 |
+
)
|
| 647 |
+
labels = torch.cat([labels, pad_labels], dim=1)
|
| 648 |
+
|
| 649 |
+
return inputs_embeds, labels
|
| 650 |
+
|
| 651 |
+
def forward(
|
| 652 |
+
self,
|
| 653 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 654 |
+
pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
|
| 655 |
+
discrete_pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
|
| 656 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 657 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 658 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 659 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 660 |
+
labels: Optional[torch.LongTensor] = None,
|
| 661 |
+
use_cache: Optional[bool] = None,
|
| 662 |
+
output_attentions: Optional[bool] = None,
|
| 663 |
+
output_hidden_states: Optional[bool] = None,
|
| 664 |
+
return_dict: Optional[bool] = None,
|
| 665 |
+
image_sizes: Optional[List[List[List[int]]]] = None,
|
| 666 |
+
mm_query_lengths: Optional[List[List[int]]] = None,
|
| 667 |
+
non_mm_query_lengths: Optional[List[List[int]]] = None,
|
| 668 |
+
img_start_ids_list: Optional[List[List[int]]] = None,
|
| 669 |
+
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
|
| 670 |
+
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
|
| 671 |
+
first_last_frames_slows: Optional[List[List[bool]]] = None,
|
| 672 |
+
is_videos: Optional[List[List[bool]]] = None,
|
| 673 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 674 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 675 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 676 |
+
video_audio_values: Optional[torch.FloatTensor] = None,
|
| 677 |
+
video_audio_masks: Optional[torch.FloatTensor] = None,
|
| 678 |
+
audio_values: Optional[torch.FloatTensor] = None,
|
| 679 |
+
discrete_audio_values: Optional[torch.FloatTensor] = None,
|
| 680 |
+
discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None,
|
| 681 |
+
audio_masks: Optional[torch.LongTensor] = None,
|
| 682 |
+
**kwargs,
|
| 683 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 684 |
+
"""
|
| 685 |
+
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer.
|
| 686 |
+
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data.
|
| 687 |
+
:param pixel_values: List of List of 4D tensor (torch.float32)
|
| 688 |
+
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
|
| 689 |
+
:param past_key_values: None
|
| 690 |
+
:param inputs_embeds: None
|
| 691 |
+
:param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1)+ num visual tokens)] visual token 들은 모두 IGNORE_INDEX
|
| 692 |
+
:param use_cache: None
|
| 693 |
+
:param output_attentions: Optional[bool] : get attention weights of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함)
|
| 694 |
+
:param output_hidden_states: Optional[bool] : get hidden states of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함)
|
| 695 |
+
:param return_dict: Optional[bool] : True - return dict, Fasle - return tensor
|
| 696 |
+
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
|
| 697 |
+
In cases where a sample contains no images, a single dummy image is included.
|
| 698 |
+
:param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
|
| 699 |
+
In cases where a sample does not contain any images, an empty list is included.
|
| 700 |
+
:param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
|
| 701 |
+
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample.
|
| 702 |
+
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
|
| 703 |
+
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
|
| 704 |
+
:first_last_frames_slows: A List of List that contains the only first and last frames slow mode for each sample in a batch.
|
| 705 |
+
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video.
|
| 706 |
+
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
|
| 707 |
+
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder.
|
| 708 |
+
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
|
| 709 |
+
:return:
|
| 710 |
+
"""
|
| 711 |
+
|
| 712 |
+
if self.sp_manager is not None and self.train_video:
|
| 713 |
+
sp_group = get_ulysses_sequence_parallel_group()
|
| 714 |
+
if sp_group is not None:
|
| 715 |
+
sp_rank = get_ulysses_sequence_parallel_rank(sp_group)
|
| 716 |
+
sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group)
|
| 717 |
+
if sp_rank == 0:
|
| 718 |
+
payload = {
|
| 719 |
+
"input_ids": input_ids,
|
| 720 |
+
"labels": labels,
|
| 721 |
+
"pixel_values": pixel_values,
|
| 722 |
+
"image_grid_thw": image_grid_thw,
|
| 723 |
+
"pixel_values_videos": pixel_values_videos,
|
| 724 |
+
"video_grid_thw": video_grid_thw,
|
| 725 |
+
"video_audio_values": video_audio_values,
|
| 726 |
+
"video_audio_masks": video_audio_masks,
|
| 727 |
+
}
|
| 728 |
+
else:
|
| 729 |
+
payload = {
|
| 730 |
+
"input_ids": None,
|
| 731 |
+
"labels": None,
|
| 732 |
+
"pixel_values": None,
|
| 733 |
+
"image_grid_thw": None,
|
| 734 |
+
"pixel_values_videos": None,
|
| 735 |
+
"video_grid_thw": None,
|
| 736 |
+
"video_audio_values": None,
|
| 737 |
+
"video_audio_masks": None,
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
obj_list = [payload]
|
| 741 |
+
src_global_rank = dist.get_global_rank(sp_group, 0)
|
| 742 |
+
dist.broadcast_object_list(
|
| 743 |
+
obj_list, src=src_global_rank, group=sp_group
|
| 744 |
+
)
|
| 745 |
+
payload = obj_list[0]
|
| 746 |
+
|
| 747 |
+
if sp_rank != 0:
|
| 748 |
+
device = input_ids.device
|
| 749 |
+
|
| 750 |
+
input_ids = payload["input_ids"]
|
| 751 |
+
if isinstance(input_ids, torch.Tensor):
|
| 752 |
+
input_ids = input_ids.to(device)
|
| 753 |
+
|
| 754 |
+
labels = payload["labels"]
|
| 755 |
+
if isinstance(labels, torch.Tensor):
|
| 756 |
+
labels = labels.to(device)
|
| 757 |
+
|
| 758 |
+
image_grid_thw = payload["image_grid_thw"]
|
| 759 |
+
if isinstance(image_grid_thw, torch.Tensor):
|
| 760 |
+
image_grid_thw = image_grid_thw.to(device)
|
| 761 |
+
|
| 762 |
+
pixel_values_videos = payload["pixel_values_videos"]
|
| 763 |
+
if isinstance(pixel_values_videos, torch.Tensor):
|
| 764 |
+
pixel_values_videos = pixel_values_videos.to(device)
|
| 765 |
+
|
| 766 |
+
video_grid_thw = payload["video_grid_thw"]
|
| 767 |
+
if isinstance(video_grid_thw, torch.Tensor):
|
| 768 |
+
video_grid_thw = video_grid_thw.to(device)
|
| 769 |
+
|
| 770 |
+
video_audio_values = payload["video_audio_values"]
|
| 771 |
+
if isinstance(video_audio_values, torch.Tensor):
|
| 772 |
+
video_audio_values = video_audio_values.to(device)
|
| 773 |
+
|
| 774 |
+
video_audio_masks = payload["video_audio_masks"]
|
| 775 |
+
if isinstance(video_audio_masks, torch.Tensor):
|
| 776 |
+
video_audio_masks = video_audio_masks.to(device)
|
| 777 |
+
|
| 778 |
+
pixel_values = payload["pixel_values"]
|
| 779 |
+
if isinstance(pixel_values, torch.Tensor):
|
| 780 |
+
pixel_values = pixel_values.to(device)
|
| 781 |
+
|
| 782 |
+
attention_mask = None
|
| 783 |
+
output_attentions = (
|
| 784 |
+
output_attentions
|
| 785 |
+
if output_attentions is not None
|
| 786 |
+
else self.config.vision_config.output_attentions
|
| 787 |
+
)
|
| 788 |
+
output_hidden_states = (
|
| 789 |
+
output_hidden_states
|
| 790 |
+
if output_hidden_states is not None
|
| 791 |
+
else self.config.vision_config.output_hidden_states
|
| 792 |
+
)
|
| 793 |
+
return_dict = (
|
| 794 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
if inputs_embeds is None and past_key_values is None:
|
| 798 |
+
inputs_embeds, labels = self.model.extract_inputs_embeds(
|
| 799 |
+
input_ids=input_ids,
|
| 800 |
+
labels=labels,
|
| 801 |
+
pixel_values=pixel_values,
|
| 802 |
+
discrete_pixel_values=discrete_pixel_values,
|
| 803 |
+
past_key_values=past_key_values,
|
| 804 |
+
image_sizes=image_sizes,
|
| 805 |
+
mm_query_lengths=mm_query_lengths,
|
| 806 |
+
non_mm_query_lengths=non_mm_query_lengths,
|
| 807 |
+
img_start_ids_list=img_start_ids_list,
|
| 808 |
+
num_queries_vis_abstractors=num_queries_vis_abstractors,
|
| 809 |
+
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
|
| 810 |
+
first_last_frames_slows=first_last_frames_slows,
|
| 811 |
+
is_videos=is_videos,
|
| 812 |
+
image_grid_thw=image_grid_thw,
|
| 813 |
+
pixel_values_videos=pixel_values_videos,
|
| 814 |
+
video_grid_thw=video_grid_thw,
|
| 815 |
+
video_audio_values=video_audio_values,
|
| 816 |
+
video_audio_masks=video_audio_masks,
|
| 817 |
+
audio_values=audio_values,
|
| 818 |
+
discrete_audio_values=discrete_audio_values,
|
| 819 |
+
discrete_audio_value_num_per_sample=discrete_audio_value_num_per_sample,
|
| 820 |
+
audio_masks=audio_masks,
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
if labels is not None and labels.size(1) > 32768:
|
| 824 |
+
print(
|
| 825 |
+
f"[RANK {rank} debug] ❌ labels.size(1) > 32768. labels.size(): {labels.size()}"
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
if inputs_embeds is not None:
|
| 829 |
+
input_ids = None
|
| 830 |
+
|
| 831 |
+
import os
|
| 832 |
+
|
| 833 |
+
rank = int(os.environ.get("RANK", -1))
|
| 834 |
+
|
| 835 |
+
if inputs_embeds is not None:
|
| 836 |
+
expected_hidden_size = self.config.text_config.hidden_size
|
| 837 |
+
if inputs_embeds.shape[-1] != expected_hidden_size:
|
| 838 |
+
print(f"[RANK {rank}] ❌ inputs_embeds dimension mismatch!")
|
| 839 |
+
print(
|
| 840 |
+
f" Expected: {expected_hidden_size}, Got: {inputs_embeds.shape[-1]}"
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
if labels is not None:
|
| 844 |
+
vocab_size = self.get_input_embeddings().num_embeddings
|
| 845 |
+
valid_labels = labels[labels != -100]
|
| 846 |
+
if len(valid_labels) > 0:
|
| 847 |
+
if (valid_labels >= vocab_size).any() or (valid_labels < 0).any():
|
| 848 |
+
print(f"[RANK {rank}] ❌ CRITICAL: labels out of vocab range!")
|
| 849 |
+
print(
|
| 850 |
+
f" labels min/max: {valid_labels.min().item()}/{valid_labels.max().item()}"
|
| 851 |
+
)
|
| 852 |
+
print(f" vocab_size: {vocab_size}")
|
| 853 |
+
print(
|
| 854 |
+
f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}"
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
if attention_mask is not None and inputs_embeds is not None:
|
| 858 |
+
if attention_mask.shape[1] != inputs_embeds.shape[1]:
|
| 859 |
+
print(f"[RANK {rank}] ❌ attention_mask shape mismatch!")
|
| 860 |
+
print(
|
| 861 |
+
f" attention_mask: {attention_mask.shape}, inputs_embeds: {inputs_embeds.shape}"
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
if position_ids is not None:
|
| 865 |
+
max_position = position_ids.max().item()
|
| 866 |
+
if hasattr(self.language_model.config, "max_position_embeddings"):
|
| 867 |
+
max_allowed = self.language_model.config.max_position_embeddings
|
| 868 |
+
if max_position >= max_allowed:
|
| 869 |
+
print(f"[RANK {rank}] ❌ position_ids out of range!")
|
| 870 |
+
print(f" max_position: {max_position}, max_allowed: {max_allowed}")
|
| 871 |
+
|
| 872 |
+
if self.sp_manager is not None:
|
| 873 |
+
|
| 874 |
+
batch_size, seqlen, hidden_size = inputs_embeds.shape
|
| 875 |
+
|
| 876 |
+
sp_group = get_ulysses_sequence_parallel_group()
|
| 877 |
+
sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group)
|
| 878 |
+
|
| 879 |
+
inputs_embeds, labels = self._pad_sequence_for_sp(
|
| 880 |
+
inputs_embeds, labels, sp_world_size
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
if position_ids is None:
|
| 884 |
+
position_ids = torch.arange(
|
| 885 |
+
seqlen, device=inputs_embeds.device, dtype=torch.long
|
| 886 |
+
)
|
| 887 |
+
position_ids = (
|
| 888 |
+
position_ids.unsqueeze(0).expand(batch_size, -1).contiguous()
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
inputs_embeds = slice_input_tensor(
|
| 892 |
+
inputs_embeds, 1, padding=False, group=sp_group
|
| 893 |
+
)
|
| 894 |
+
labels = slice_input_tensor(labels, 1, padding=False, group=sp_group)
|
| 895 |
+
use_cache = False
|
| 896 |
+
|
| 897 |
+
outputs = self.language_model.base_model(
|
| 898 |
+
input_ids=input_ids,
|
| 899 |
+
inputs_embeds=inputs_embeds,
|
| 900 |
+
attention_mask=attention_mask,
|
| 901 |
+
position_ids=position_ids,
|
| 902 |
+
past_key_values=past_key_values,
|
| 903 |
+
use_cache=use_cache,
|
| 904 |
+
output_attentions=output_attentions,
|
| 905 |
+
output_hidden_states=output_hidden_states,
|
| 906 |
+
return_dict=return_dict,
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
hidden_states = outputs[0]
|
| 910 |
+
hidden_states = hidden_states * self.config.text_config.logits_scaling
|
| 911 |
+
|
| 912 |
+
loss = None
|
| 913 |
+
logits = None
|
| 914 |
+
|
| 915 |
+
if labels is not None:
|
| 916 |
+
if self.use_liger and self.use_fused_ce:
|
| 917 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 918 |
+
shift_labels = shift_labels.view(-1)
|
| 919 |
+
|
| 920 |
+
hidden_states = hidden_states[..., :-1, :].contiguous()
|
| 921 |
+
hidden_states = hidden_states.view(
|
| 922 |
+
-1, self.language_model.config.hidden_size
|
| 923 |
+
).to(self.language_model.lm_head.weight.dtype)
|
| 924 |
+
|
| 925 |
+
import os
|
| 926 |
+
|
| 927 |
+
rank = int(os.environ.get("RANK", -1))
|
| 928 |
+
|
| 929 |
+
vocab_size = self.language_model.lm_head.weight.shape[0]
|
| 930 |
+
valid_labels = shift_labels[shift_labels != -100]
|
| 931 |
+
if len(valid_labels) > 0 and (
|
| 932 |
+
(valid_labels >= vocab_size).any() or (valid_labels < 0).any()
|
| 933 |
+
):
|
| 934 |
+
print(
|
| 935 |
+
f"[RANK {rank}] ❌ CRITICAL: shift_labels out of vocab range!"
|
| 936 |
+
)
|
| 937 |
+
print(
|
| 938 |
+
f" min/max: {valid_labels.min().item()}/{valid_labels.max().item()}, vocab: {vocab_size}"
|
| 939 |
+
)
|
| 940 |
+
print(
|
| 941 |
+
f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}"
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
lce = LigerFusedLinearCrossEntropyLoss(reduction=self.reduction)
|
| 945 |
+
try:
|
| 946 |
+
loss = lce(
|
| 947 |
+
self.language_model.lm_head.weight, hidden_states, shift_labels
|
| 948 |
+
)
|
| 949 |
+
except RuntimeError as e:
|
| 950 |
+
print(
|
| 951 |
+
f"[RANK {rank}] ❌ FATAL: LigerFusedLinearCrossEntropyLoss failed!"
|
| 952 |
+
)
|
| 953 |
+
print(f" Error: {e}")
|
| 954 |
+
print(
|
| 955 |
+
f" hidden_states: shape={hidden_states.shape}, dtype={hidden_states.dtype}"
|
| 956 |
+
)
|
| 957 |
+
print(
|
| 958 |
+
f" shift_labels: shape={shift_labels.shape}, unique_values={torch.unique(shift_labels).tolist()[:20]}"
|
| 959 |
+
)
|
| 960 |
+
print(
|
| 961 |
+
f" lm_head.weight: shape={self.language_model.lm_head.weight.shape}"
|
| 962 |
+
)
|
| 963 |
+
raise
|
| 964 |
+
elif self.use_liger:
|
| 965 |
+
logits = self.language_model.lm_head(hidden_states)
|
| 966 |
+
|
| 967 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 968 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 969 |
+
|
| 970 |
+
loss_fct = LigerCrossEntropyLoss(reduction=self.reduction)
|
| 971 |
+
shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 972 |
+
shift_labels = shift_labels.view(-1)
|
| 973 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 974 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 975 |
+
else:
|
| 976 |
+
logits = self.language_model.lm_head(hidden_states)
|
| 977 |
+
|
| 978 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 979 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 980 |
+
|
| 981 |
+
loss_fct = CrossEntropyLoss(reduction=self.reduction)
|
| 982 |
+
shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 983 |
+
shift_labels = shift_labels.view(-1)
|
| 984 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 985 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 986 |
+
|
| 987 |
+
if self.sp_manager is not None:
|
| 988 |
+
loss = gather_outputs_and_unpad(
|
| 989 |
+
loss, gather_dim=0, unpad_dim=0, padding_size=0, group=sp_group
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
if self.use_meansum_loss:
|
| 993 |
+
loss = loss.view(labels.size(0), -1).mean(dim=1).sum()
|
| 994 |
+
|
| 995 |
+
elif self.use_sqrtsum_loss:
|
| 996 |
+
per_token = loss.view(labels.size(0), -1)
|
| 997 |
+
per_sample_mean = per_token.mean(dim=1)
|
| 998 |
+
|
| 999 |
+
with torch.no_grad():
|
| 1000 |
+
labels_2d = labels.view(labels.size(0), -1)
|
| 1001 |
+
ignore_index = getattr(self.config, "ignore_index", -100)
|
| 1002 |
+
valid_mask = labels_2d.ne(ignore_index)
|
| 1003 |
+
valid_count = valid_mask.sum(dim=1).clamp(min=1).float()
|
| 1004 |
+
raw_w = valid_count.sqrt()
|
| 1005 |
+
w_mean = raw_w.mean().clamp(min=1e-6)
|
| 1006 |
+
norm_w = raw_w / w_mean
|
| 1007 |
+
|
| 1008 |
+
loss = (per_sample_mean * norm_w).sum()
|
| 1009 |
+
|
| 1010 |
+
elif self.use_turnmeansum_loss:
|
| 1011 |
+
with torch.no_grad():
|
| 1012 |
+
mask = shift_labels.view(labels.size(0), -1).ne(
|
| 1013 |
+
self.config.ignore_index
|
| 1014 |
+
)
|
| 1015 |
+
prev_mask = mask.roll(shifts=1, dims=1)
|
| 1016 |
+
prev_mask[:, 0] = False
|
| 1017 |
+
|
| 1018 |
+
turn_starts = mask & (~prev_mask)
|
| 1019 |
+
|
| 1020 |
+
turn_count = turn_starts.sum(dim=1).clamp(min=1).float()
|
| 1021 |
+
|
| 1022 |
+
loss = (loss.view(labels.size(0), -1).mean(dim=1) * turn_count).sum()
|
| 1023 |
+
|
| 1024 |
+
if self.sp_manager is not None:
|
| 1025 |
+
loss = loss / self.sp_manager.device_mesh.shape[1]
|
| 1026 |
+
if not return_dict:
|
| 1027 |
+
output = (logits,) + outputs[1:]
|
| 1028 |
+
return (loss,) + output if loss is not None else output
|
| 1029 |
+
|
| 1030 |
+
return CausalLMOutputWithPast(
|
| 1031 |
+
loss=loss,
|
| 1032 |
+
logits=logits,
|
| 1033 |
+
past_key_values=outputs.past_key_values,
|
| 1034 |
+
hidden_states=outputs.hidden_states,
|
| 1035 |
+
attentions=outputs.attentions,
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
def save_pretrained(
|
| 1039 |
+
self,
|
| 1040 |
+
save_directory: Union[str, os.PathLike],
|
| 1041 |
+
*args,
|
| 1042 |
+
**kwargs,
|
| 1043 |
+
):
|
| 1044 |
+
|
| 1045 |
+
state_dict = (
|
| 1046 |
+
kwargs["state_dict"]
|
| 1047 |
+
if kwargs.get("state_dict", None)
|
| 1048 |
+
else self.state_dict()
|
| 1049 |
+
)
|
| 1050 |
+
partial_state_dict = self.get_pretrained_state_dict(
|
| 1051 |
+
state_dict,
|
| 1052 |
+
)
|
| 1053 |
+
kwargs["state_dict"] = partial_state_dict
|
| 1054 |
+
kwargs["safe_serialization"] = self.is_safetensor_save
|
| 1055 |
+
kwargs.setdefault("max_shard_size", self.save_shard_size)
|
| 1056 |
+
super().save_pretrained(save_directory, *args, **kwargs)
|
| 1057 |
+
if self.is_qwen_visual:
|
| 1058 |
+
self.config.architectures = ["HCXVisionV2ForCausalLM"]
|
| 1059 |
+
else:
|
| 1060 |
+
self.config.architectures = ["HCXVisionForCausalLM"]
|
| 1061 |
+
self.config.auto_map["AutoModelForCausalLM"] = (
|
| 1062 |
+
"modeling_vlm.HCXVisionForCausalLM"
|
| 1063 |
+
)
|
| 1064 |
+
self.config.auto_map["AutoModelForSequenceClassification"] = (
|
| 1065 |
+
"modeling_vlm.HCXVisionForSequenceClassification"
|
| 1066 |
+
)
|
| 1067 |
+
self.config.save_pretrained(save_directory)
|
| 1068 |
+
|
| 1069 |
+
def get_pretrained_state_dict(self, state_dict):
|
| 1070 |
+
vision_key = "vision_model."
|
| 1071 |
+
llm_keys = ["language_model."]
|
| 1072 |
+
head_key = "lm_head."
|
| 1073 |
+
|
| 1074 |
+
for key in list(state_dict.keys()):
|
| 1075 |
+
if self.save_only_vision:
|
| 1076 |
+
for llm_key in llm_keys:
|
| 1077 |
+
if llm_key in key:
|
| 1078 |
+
state_dict.pop(key)
|
| 1079 |
+
if key.startswith(head_key):
|
| 1080 |
+
state_dict.pop(key)
|
| 1081 |
+
elif self.save_only_qformer:
|
| 1082 |
+
if f"{vision_key}" in key:
|
| 1083 |
+
state_dict.pop(key)
|
| 1084 |
+
|
| 1085 |
+
return state_dict
|
preprocessor.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_vlm.HCXVisionV2Processor"
|
| 4 |
+
},
|
| 5 |
+
"do_convert_rgb": true,
|
| 6 |
+
"do_normalize": true,
|
| 7 |
+
"do_rescale": true,
|
| 8 |
+
"do_resize": true,
|
| 9 |
+
"image_mean": [
|
| 10 |
+
0.48145466,
|
| 11 |
+
0.4578275,
|
| 12 |
+
0.40821073
|
| 13 |
+
],
|
| 14 |
+
"image_processor_type": "Qwen2VLImageProcessor",
|
| 15 |
+
"image_std": [
|
| 16 |
+
0.26862954,
|
| 17 |
+
0.26130258,
|
| 18 |
+
0.27577711
|
| 19 |
+
],
|
| 20 |
+
"max_pixels": 2073600,
|
| 21 |
+
"merge_size": 2,
|
| 22 |
+
"min_pixels": 3136,
|
| 23 |
+
"patch_size": 14,
|
| 24 |
+
"processor_class": "HCXVisionV2Processor",
|
| 25 |
+
"resample": 3,
|
| 26 |
+
"rescale_factor": 0.00392156862745098,
|
| 27 |
+
"size": {
|
| 28 |
+
"longest_edge": 12845056,
|
| 29 |
+
"shortest_edge": 3136
|
| 30 |
+
},
|
| 31 |
+
"temporal_patch_size": 2
|
| 32 |
+
}
|
processing_vlm.py
ADDED
|
@@ -0,0 +1,963 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from transformers import Qwen2_5_VLProcessor
|
| 10 |
+
from transformers.image_processing_utils import (
|
| 11 |
+
BaseImageProcessor,
|
| 12 |
+
BatchFeature,
|
| 13 |
+
get_size_dict,
|
| 14 |
+
)
|
| 15 |
+
from transformers.image_transforms import (
|
| 16 |
+
convert_to_rgb,
|
| 17 |
+
get_resize_output_image_size,
|
| 18 |
+
resize,
|
| 19 |
+
to_channel_dimension_format,
|
| 20 |
+
)
|
| 21 |
+
from transformers.image_utils import (
|
| 22 |
+
OPENAI_CLIP_MEAN,
|
| 23 |
+
OPENAI_CLIP_STD,
|
| 24 |
+
ChannelDimension,
|
| 25 |
+
ImageInput,
|
| 26 |
+
PILImageResampling,
|
| 27 |
+
get_image_size,
|
| 28 |
+
infer_channel_dimension_format,
|
| 29 |
+
is_scaled_image,
|
| 30 |
+
make_list_of_images,
|
| 31 |
+
to_numpy_array,
|
| 32 |
+
valid_images,
|
| 33 |
+
)
|
| 34 |
+
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import (
|
| 35 |
+
Qwen2_5_VLProcessorKwargs,
|
| 36 |
+
)
|
| 37 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 38 |
+
from transformers.utils import TensorType, logging
|
| 39 |
+
from transformers.video_utils import VideoInput
|
| 40 |
+
from typing_extensions import Unpack
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def determine_possible_resolutions(
|
| 46 |
+
anyres: bool, max_num_grids: int, grid_size: int, use_1x1_grid: bool = False
|
| 47 |
+
):
|
| 48 |
+
"""총 max_num_grids 이하의 possible resolution 조합을 찾아 반환합니다.
|
| 49 |
+
max_num_grids 가 예를 들어 4인 경우, 총 가능한 grid 조합은 [1x1, 1x2, 1x3, 1x4, 2x1, 2x2, 3x1, 4x1] 이고, 따라서 아래와 같이 계산됩니다.
|
| 50 |
+
>>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336)
|
| 51 |
+
>>> print(possible_resolutions)
|
| 52 |
+
[[336, 336], [336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]]
|
| 53 |
+
"""
|
| 54 |
+
possible_resolutions = []
|
| 55 |
+
if anyres:
|
| 56 |
+
assert max_num_grids > 0
|
| 57 |
+
for i in range(1, max_num_grids + 1):
|
| 58 |
+
for j in range(1, max_num_grids + 1):
|
| 59 |
+
if i == 1 and j == 1 and not use_1x1_grid:
|
| 60 |
+
continue
|
| 61 |
+
if i * j <= max_num_grids:
|
| 62 |
+
possible_resolutions.append([i, j])
|
| 63 |
+
|
| 64 |
+
possible_resolutions = [
|
| 65 |
+
[ys * grid_size, xs * grid_size] for ys, xs in possible_resolutions
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
return possible_resolutions
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def divide_to_grids(
|
| 72 |
+
image: np.array, grid_size: int, input_data_format=None
|
| 73 |
+
) -> List[np.array]:
|
| 74 |
+
"""local image 를 (grid_size x grid_size) grid 로 divide"""
|
| 75 |
+
grids = []
|
| 76 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
| 77 |
+
for i in range(0, height, grid_size):
|
| 78 |
+
for j in range(0, width, grid_size):
|
| 79 |
+
if input_data_format == ChannelDimension.LAST:
|
| 80 |
+
grid = image[i : i + grid_size, j : j + grid_size]
|
| 81 |
+
else:
|
| 82 |
+
grid = image[:, i : i + grid_size, j : j + grid_size]
|
| 83 |
+
grids.append(grid)
|
| 84 |
+
|
| 85 |
+
return grids
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def pad(
|
| 89 |
+
image: np.array,
|
| 90 |
+
target_size: tuple,
|
| 91 |
+
background_color=(127, 127, 127),
|
| 92 |
+
input_data_format=None,
|
| 93 |
+
) -> np.array:
|
| 94 |
+
"""image 양옆, 좌우에 padding 을 하여 target_height, target_width 만큼 키움"""
|
| 95 |
+
target_height, target_width = target_size
|
| 96 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
| 97 |
+
|
| 98 |
+
result = np.empty((target_height, target_width, image.shape[2]), dtype=image.dtype)
|
| 99 |
+
for i in range(image.shape[2]):
|
| 100 |
+
result[..., i].fill(background_color[i])
|
| 101 |
+
|
| 102 |
+
paste_x = (target_width - width) // 2
|
| 103 |
+
paste_y = (target_height - height) // 2
|
| 104 |
+
|
| 105 |
+
result[paste_y : paste_y + height, paste_x : paste_x + width, :] = image
|
| 106 |
+
|
| 107 |
+
return result
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def expand2square(
|
| 111 |
+
image: np.array,
|
| 112 |
+
bboxes_dict=None,
|
| 113 |
+
background_color=(127, 127, 127),
|
| 114 |
+
input_data_format=None,
|
| 115 |
+
) -> np.array:
|
| 116 |
+
"""
|
| 117 |
+
새로운 canvas 를 만들어 두고, 거기에 이미지를 붙여넣는 방식으로 이미지를 정사각형으로 만드는 함수
|
| 118 |
+
유의할 사항은, 이미지를 붙여 넣을 때 중앙으로 붙여넣는다는 점. 양옆 또는 위아래로 PADDING 이 들어가는 형태
|
| 119 |
+
Args:
|
| 120 |
+
pil_img: numpy array
|
| 121 |
+
bboxes_dict: dict, {"ocr": NDArray shape (N, 4, 2), "html": NDArray shape (N, 4, 2), ... }
|
| 122 |
+
`[[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]]` 형태로 박스 형태는 통일. OCR, HTML 등 다양한 박스들을 한번에 처리 가능
|
| 123 |
+
background_color: tuple, RGB
|
| 124 |
+
# >>> _img = np.ones((80, 100), dtype=np.uint8) * 100
|
| 125 |
+
# >>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]],
|
| 126 |
+
# ... [[30, 30], [40, 30], [40, 40], [30, 40]]])}
|
| 127 |
+
# >>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255))
|
| 128 |
+
# >>> _img.shape
|
| 129 |
+
# (100, 100)
|
| 130 |
+
# >>> guessed_ocr_bboxes = np.array([[[20, 10], [30, 10], [30, 20], [20, 20]],
|
| 131 |
+
# ... [[40, 30], [50, 30], [50, 40], [40, 40]]])
|
| 132 |
+
# >>> np.testing.assert_array_almost_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None
|
| 133 |
+
# True
|
| 134 |
+
"""
|
| 135 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
| 136 |
+
if width == height:
|
| 137 |
+
return image, bboxes_dict
|
| 138 |
+
elif width > height:
|
| 139 |
+
result = np.empty((width, width, image.shape[2]), dtype=image.dtype)
|
| 140 |
+
for i in range(image.shape[2]):
|
| 141 |
+
result[..., i].fill(background_color[i])
|
| 142 |
+
|
| 143 |
+
result[(width - height) // 2 : (width - height) // 2 + height, :] = image
|
| 144 |
+
if bboxes_dict is not None:
|
| 145 |
+
for key in bboxes_dict:
|
| 146 |
+
bboxes_dict[key][:, :, 1] += (width - height) // 2
|
| 147 |
+
return result, bboxes_dict
|
| 148 |
+
else:
|
| 149 |
+
result = np.empty((height, height, image.shape[2]), dtype=image.dtype)
|
| 150 |
+
for i in range(image.shape[2]):
|
| 151 |
+
result[..., i].fill(background_color[i])
|
| 152 |
+
|
| 153 |
+
result[:, (height - width) // 2 : (height - width) // 2 + width] = image
|
| 154 |
+
if bboxes_dict is not None:
|
| 155 |
+
for key in bboxes_dict:
|
| 156 |
+
bboxes_dict[key][:, :, 0] += (height - width) // 2
|
| 157 |
+
return result, bboxes_dict
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def resize_longside(
|
| 161 |
+
image: np.array,
|
| 162 |
+
size: int,
|
| 163 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 164 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 165 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 166 |
+
):
|
| 167 |
+
"""
|
| 168 |
+
장축 길이를 size 에 맞게 resize
|
| 169 |
+
"""
|
| 170 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
| 171 |
+
|
| 172 |
+
if width == height:
|
| 173 |
+
target_height, target_width = size, size
|
| 174 |
+
elif width > height:
|
| 175 |
+
target_width = size
|
| 176 |
+
target_height = math.ceil(height / width * size)
|
| 177 |
+
else:
|
| 178 |
+
target_width = math.ceil(width / height * size)
|
| 179 |
+
target_height = size
|
| 180 |
+
|
| 181 |
+
return resize(
|
| 182 |
+
image,
|
| 183 |
+
size=(target_height, target_width),
|
| 184 |
+
resample=resample,
|
| 185 |
+
data_format=data_format,
|
| 186 |
+
input_data_format=input_data_format,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
|
| 191 |
+
"""From LLaVA-Next (https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py)
|
| 192 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
| 193 |
+
This is done by calculating the effective and wasted resolution for each possible resolution.
|
| 194 |
+
The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
original_size (tuple):
|
| 198 |
+
The original size of the image in the format (height, width).
|
| 199 |
+
possible_resolutions (list):
|
| 200 |
+
A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
tuple: The best fit resolution in the format (height, width).
|
| 204 |
+
"""
|
| 205 |
+
original_height, original_width = original_size
|
| 206 |
+
best_fit = None
|
| 207 |
+
max_effective_resolution = 0
|
| 208 |
+
min_wasted_resolution = float("inf")
|
| 209 |
+
|
| 210 |
+
for height, width in possible_resolutions:
|
| 211 |
+
scale = min(width / original_width, height / original_height)
|
| 212 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(
|
| 213 |
+
original_height * scale
|
| 214 |
+
)
|
| 215 |
+
effective_resolution = min(
|
| 216 |
+
downscaled_width * downscaled_height, original_width * original_height
|
| 217 |
+
)
|
| 218 |
+
wasted_resolution = (width * height) - effective_resolution
|
| 219 |
+
|
| 220 |
+
if effective_resolution > max_effective_resolution or (
|
| 221 |
+
effective_resolution == max_effective_resolution
|
| 222 |
+
and wasted_resolution < min_wasted_resolution
|
| 223 |
+
):
|
| 224 |
+
max_effective_resolution = effective_resolution
|
| 225 |
+
min_wasted_resolution = wasted_resolution
|
| 226 |
+
best_fit = (height, width)
|
| 227 |
+
|
| 228 |
+
return best_fit
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _get_local_grids_output_size(
|
| 232 |
+
image: np.array, target_resolution: tuple, input_data_format=None
|
| 233 |
+
):
|
| 234 |
+
original_height, original_width = get_image_size(
|
| 235 |
+
image, channel_dim=input_data_format
|
| 236 |
+
)
|
| 237 |
+
target_height, target_width = target_resolution
|
| 238 |
+
|
| 239 |
+
scale_w = target_width / original_width
|
| 240 |
+
scale_h = target_height / original_height
|
| 241 |
+
|
| 242 |
+
if scale_w < scale_h:
|
| 243 |
+
new_width = target_width
|
| 244 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
| 245 |
+
else:
|
| 246 |
+
new_height = target_height
|
| 247 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
| 248 |
+
|
| 249 |
+
return new_height, new_width
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def determine_anyres_num_vision_patches(
|
| 253 |
+
num_grids,
|
| 254 |
+
image_size,
|
| 255 |
+
grid_size,
|
| 256 |
+
patch_size,
|
| 257 |
+
possible_resolutions,
|
| 258 |
+
anyres=False,
|
| 259 |
+
unpad=True,
|
| 260 |
+
num_queries_vis_abstractor=0,
|
| 261 |
+
num_queries_vis_abstractor_slow=0,
|
| 262 |
+
video=False,
|
| 263 |
+
first_last_frames_slow=False,
|
| 264 |
+
is_first_or_last_frames=False,
|
| 265 |
+
):
|
| 266 |
+
"""visual tokens 수를 계산해주는 함수"""
|
| 267 |
+
if not anyres:
|
| 268 |
+
return (
|
| 269 |
+
num_queries_vis_abstractor
|
| 270 |
+
if num_queries_vis_abstractor > 0
|
| 271 |
+
else (grid_size // patch_size) ** 2
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
if num_queries_vis_abstractor > 0:
|
| 275 |
+
num_patch_per_grid = int(num_queries_vis_abstractor**0.5)
|
| 276 |
+
else:
|
| 277 |
+
num_patch_per_grid = grid_size // patch_size
|
| 278 |
+
|
| 279 |
+
num_global_per_grid = num_patch_per_grid
|
| 280 |
+
|
| 281 |
+
height, width = select_best_resolution(image_size, possible_resolutions)
|
| 282 |
+
|
| 283 |
+
num_patch_height = (height // grid_size) * num_patch_per_grid
|
| 284 |
+
num_patch_width = (width // grid_size) * num_patch_per_grid
|
| 285 |
+
|
| 286 |
+
if unpad:
|
| 287 |
+
original_height, original_width = image_size
|
| 288 |
+
|
| 289 |
+
original_aspect_ratio = original_width / original_height
|
| 290 |
+
current_aspect_ratio = num_patch_width / num_patch_height
|
| 291 |
+
|
| 292 |
+
if original_aspect_ratio > current_aspect_ratio:
|
| 293 |
+
scale_factor = num_patch_width / original_width
|
| 294 |
+
new_height = int(original_height * scale_factor)
|
| 295 |
+
padding = (num_patch_height - new_height) // 2
|
| 296 |
+
num_patch_height = num_patch_height - padding * 2
|
| 297 |
+
else:
|
| 298 |
+
scale_factor = num_patch_height / original_height
|
| 299 |
+
new_width = int(original_width * scale_factor)
|
| 300 |
+
padding = (num_patch_width - new_width) // 2
|
| 301 |
+
num_patch_width = num_patch_width - padding * 2
|
| 302 |
+
|
| 303 |
+
num_patches = num_patch_width * num_patch_height + num_patch_height
|
| 304 |
+
else:
|
| 305 |
+
num_patches = num_patch_width * num_patch_height
|
| 306 |
+
|
| 307 |
+
if num_queries_vis_abstractor_slow > 0:
|
| 308 |
+
if first_last_frames_slow:
|
| 309 |
+
if is_first_or_last_frames:
|
| 310 |
+
num_patches += (
|
| 311 |
+
num_queries_vis_abstractor_slow - num_queries_vis_abstractor
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
|
| 315 |
+
assert unpad is False
|
| 316 |
+
|
| 317 |
+
if not video:
|
| 318 |
+
num_patches += num_global_per_grid**2
|
| 319 |
+
|
| 320 |
+
return num_patches
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class HCXVisionImageProcessor(BaseImageProcessor):
|
| 324 |
+
r"""
|
| 325 |
+
Constructs a VLM image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques for processing high resolution images.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
anyres: (bool) anyres 기능을 사용할지 안할지
|
| 329 |
+
unpad: (bool) anyres 사용시, unpad 기능 (순수 pad 영역에 해당하는 visual tokens 은 LLM input 에서 제거) 을 사용할지 안할지
|
| 330 |
+
num_queries_vis_abstractor: (int) 각 grid 에 대해서 resampler 를 사용하는 경우, visual query 수
|
| 331 |
+
possible_resolutions: (List) anyres 기능 사용시, 가능한 resolution 조합, 예: [[336, 336], [336, 672], [672, 336]]
|
| 332 |
+
patch_size: (int) ViT patch size
|
| 333 |
+
pad_to_square: (bool) 정사각형으로 padding 을 수행할지, 안할지를 결정. False 이면 정사각형이 아니기 때문에 center crop 을 거쳐 ViT 의 입력으로 들어감
|
| 334 |
+
"""
|
| 335 |
+
|
| 336 |
+
model_input_names = ["pixel_values"]
|
| 337 |
+
|
| 338 |
+
def __init__(
|
| 339 |
+
self,
|
| 340 |
+
do_resize: bool = True,
|
| 341 |
+
size: Dict[str, int] = None,
|
| 342 |
+
anyres: bool = False,
|
| 343 |
+
unpad: bool = False,
|
| 344 |
+
num_queries_vis_abstractor: int = 0,
|
| 345 |
+
possible_resolutions: List = [],
|
| 346 |
+
patch_size: int = 14,
|
| 347 |
+
pad_to_square: bool = True,
|
| 348 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 349 |
+
do_center_crop: bool = True,
|
| 350 |
+
crop_size: Dict[str, int] = None,
|
| 351 |
+
do_rescale: bool = True,
|
| 352 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 353 |
+
do_normalize: bool = True,
|
| 354 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 355 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 356 |
+
do_convert_rgb: bool = True,
|
| 357 |
+
**kwargs,
|
| 358 |
+
) -> None:
|
| 359 |
+
super().__init__(**kwargs)
|
| 360 |
+
size = size if size is not None else {"shortest_edge": 336}
|
| 361 |
+
size = get_size_dict(size, default_to_square=False)
|
| 362 |
+
crop_size = (
|
| 363 |
+
crop_size if crop_size is not None else {"height": 336, "width": 336}
|
| 364 |
+
)
|
| 365 |
+
crop_size = get_size_dict(
|
| 366 |
+
crop_size, default_to_square=True, param_name="crop_size"
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
self.do_resize = do_resize
|
| 370 |
+
self.size = size
|
| 371 |
+
self.anyres = anyres
|
| 372 |
+
self.unpad = unpad
|
| 373 |
+
self.num_queries_vis_abstractor = num_queries_vis_abstractor
|
| 374 |
+
self.possible_resolutions = [
|
| 375 |
+
_resolution for _resolution in possible_resolutions
|
| 376 |
+
]
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.pad_to_square = pad_to_square
|
| 379 |
+
self.resample = resample
|
| 380 |
+
self.do_center_crop = do_center_crop
|
| 381 |
+
self.crop_size = crop_size
|
| 382 |
+
self.do_rescale = do_rescale
|
| 383 |
+
self.rescale_factor = rescale_factor
|
| 384 |
+
self.do_normalize = do_normalize
|
| 385 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
| 386 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
| 387 |
+
self.do_convert_rgb = do_convert_rgb
|
| 388 |
+
|
| 389 |
+
def resize(
|
| 390 |
+
self,
|
| 391 |
+
image: np.ndarray,
|
| 392 |
+
size: Dict[str, int],
|
| 393 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 394 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 395 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 396 |
+
**kwargs,
|
| 397 |
+
) -> np.ndarray:
|
| 398 |
+
default_to_square = True
|
| 399 |
+
if "shortest_edge" in size:
|
| 400 |
+
size = size["shortest_edge"]
|
| 401 |
+
default_to_square = False
|
| 402 |
+
elif "height" in size and "width" in size:
|
| 403 |
+
size = (size["height"], size["width"])
|
| 404 |
+
else:
|
| 405 |
+
raise ValueError(
|
| 406 |
+
"Size must contain either 'shortest_edge' or 'height' and 'width'."
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
output_size = get_resize_output_image_size(
|
| 410 |
+
image,
|
| 411 |
+
size=size,
|
| 412 |
+
default_to_square=default_to_square,
|
| 413 |
+
input_data_format=input_data_format,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
return resize(
|
| 417 |
+
image,
|
| 418 |
+
size=output_size,
|
| 419 |
+
resample=resample,
|
| 420 |
+
data_format=data_format,
|
| 421 |
+
input_data_format=input_data_format,
|
| 422 |
+
**kwargs,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
def _preprocess(
|
| 426 |
+
self,
|
| 427 |
+
images: ImageInput,
|
| 428 |
+
do_resize: bool = None,
|
| 429 |
+
size: Dict[str, int] = None,
|
| 430 |
+
resample: PILImageResampling = None,
|
| 431 |
+
do_center_crop: bool = None,
|
| 432 |
+
crop_size: int = None,
|
| 433 |
+
do_rescale: bool = None,
|
| 434 |
+
rescale_factor: float = None,
|
| 435 |
+
do_normalize: bool = None,
|
| 436 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 437 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 438 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 439 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 440 |
+
) -> Image.Image:
|
| 441 |
+
images = make_list_of_images(images)
|
| 442 |
+
|
| 443 |
+
if do_resize:
|
| 444 |
+
images = [
|
| 445 |
+
self.resize(
|
| 446 |
+
image=image,
|
| 447 |
+
size=size,
|
| 448 |
+
resample=resample,
|
| 449 |
+
input_data_format=input_data_format,
|
| 450 |
+
)
|
| 451 |
+
for image in images
|
| 452 |
+
]
|
| 453 |
+
|
| 454 |
+
if do_center_crop:
|
| 455 |
+
images = [
|
| 456 |
+
self.center_crop(
|
| 457 |
+
image=image, size=crop_size, input_data_format=input_data_format
|
| 458 |
+
)
|
| 459 |
+
for image in images
|
| 460 |
+
]
|
| 461 |
+
|
| 462 |
+
if do_rescale:
|
| 463 |
+
images = [
|
| 464 |
+
self.rescale(
|
| 465 |
+
image=image,
|
| 466 |
+
scale=rescale_factor,
|
| 467 |
+
input_data_format=input_data_format,
|
| 468 |
+
)
|
| 469 |
+
for image in images
|
| 470 |
+
]
|
| 471 |
+
|
| 472 |
+
if do_normalize:
|
| 473 |
+
images = [
|
| 474 |
+
self.normalize(
|
| 475 |
+
image=image,
|
| 476 |
+
mean=image_mean,
|
| 477 |
+
std=image_std,
|
| 478 |
+
input_data_format=input_data_format,
|
| 479 |
+
)
|
| 480 |
+
for image in images
|
| 481 |
+
]
|
| 482 |
+
|
| 483 |
+
images = [
|
| 484 |
+
to_channel_dimension_format(
|
| 485 |
+
image, data_format, input_channel_dim=input_data_format
|
| 486 |
+
)
|
| 487 |
+
for image in images
|
| 488 |
+
]
|
| 489 |
+
|
| 490 |
+
return images
|
| 491 |
+
|
| 492 |
+
def _resize_for_local_grids(
|
| 493 |
+
self,
|
| 494 |
+
image: np.array,
|
| 495 |
+
target_resolution: tuple,
|
| 496 |
+
resample,
|
| 497 |
+
input_data_format: ChannelDimension,
|
| 498 |
+
) -> np.array:
|
| 499 |
+
new_height, new_width = _get_local_grids_output_size(
|
| 500 |
+
image, target_resolution, input_data_format
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
resized_image = resize(
|
| 504 |
+
image,
|
| 505 |
+
(new_height, new_width),
|
| 506 |
+
resample=resample,
|
| 507 |
+
input_data_format=input_data_format,
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
return resized_image
|
| 511 |
+
|
| 512 |
+
def _pad_for_patching(
|
| 513 |
+
self,
|
| 514 |
+
image: np.array,
|
| 515 |
+
target_resolution: tuple,
|
| 516 |
+
input_data_format: ChannelDimension,
|
| 517 |
+
) -> np.array:
|
| 518 |
+
"""
|
| 519 |
+
Pad an image to a target resolution while maintaining aspect ratio.
|
| 520 |
+
"""
|
| 521 |
+
target_height, target_width = target_resolution
|
| 522 |
+
|
| 523 |
+
background_color = tuple(int(x * 255) for x in self.image_mean)
|
| 524 |
+
padded_image = pad(
|
| 525 |
+
image,
|
| 526 |
+
target_size=(target_height, target_width),
|
| 527 |
+
background_color=background_color,
|
| 528 |
+
input_data_format=input_data_format,
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
return padded_image
|
| 532 |
+
|
| 533 |
+
def get_image_grids(
|
| 534 |
+
self,
|
| 535 |
+
image: np.array,
|
| 536 |
+
possible_resolutions,
|
| 537 |
+
grid_size: int,
|
| 538 |
+
resample: PILImageResampling,
|
| 539 |
+
data_format: ChannelDimension,
|
| 540 |
+
input_data_format: ChannelDimension,
|
| 541 |
+
) -> List[np.array]:
|
| 542 |
+
if not isinstance(possible_resolutions, list):
|
| 543 |
+
raise ValueError(
|
| 544 |
+
"possible_resolutions must be a list of possible resolutions."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
image_size = get_image_size(image, channel_dim=input_data_format)
|
| 548 |
+
best_resolution = select_best_resolution(image_size, possible_resolutions)
|
| 549 |
+
resized_image = self._resize_for_local_grids(
|
| 550 |
+
image,
|
| 551 |
+
best_resolution,
|
| 552 |
+
resample=resample,
|
| 553 |
+
input_data_format=input_data_format,
|
| 554 |
+
)
|
| 555 |
+
padded_image = self._pad_for_patching(
|
| 556 |
+
resized_image, best_resolution, input_data_format=input_data_format
|
| 557 |
+
)
|
| 558 |
+
local_grids = divide_to_grids(
|
| 559 |
+
padded_image, grid_size=grid_size, input_data_format=input_data_format
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
local_grids = [
|
| 563 |
+
to_channel_dimension_format(
|
| 564 |
+
grid, channel_dim=data_format, input_channel_dim=input_data_format
|
| 565 |
+
)
|
| 566 |
+
for grid in local_grids
|
| 567 |
+
]
|
| 568 |
+
|
| 569 |
+
return local_grids
|
| 570 |
+
|
| 571 |
+
def preprocess(
|
| 572 |
+
self,
|
| 573 |
+
images: ImageInput,
|
| 574 |
+
do_resize: bool = None,
|
| 575 |
+
size: Dict[str, int] = None,
|
| 576 |
+
anyres: bool = None,
|
| 577 |
+
unpad: bool = None,
|
| 578 |
+
video: bool = None,
|
| 579 |
+
num_queries_vis_abstractor: int = None,
|
| 580 |
+
possible_resolutions: List = None,
|
| 581 |
+
patch_size: int = None,
|
| 582 |
+
pad_to_square: bool = None,
|
| 583 |
+
resample: PILImageResampling = None,
|
| 584 |
+
do_center_crop: bool = None,
|
| 585 |
+
crop_size: int = None,
|
| 586 |
+
do_rescale: bool = None,
|
| 587 |
+
rescale_factor: float = None,
|
| 588 |
+
do_normalize: bool = None,
|
| 589 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 590 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 591 |
+
do_convert_rgb: bool = None,
|
| 592 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 593 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 594 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 595 |
+
return_dummy_image: bool = False,
|
| 596 |
+
num_queries_vis_abstractor_slow: int = 0,
|
| 597 |
+
first_last_frames_slow: bool = False,
|
| 598 |
+
is_first_or_last_frames: bool = False,
|
| 599 |
+
):
|
| 600 |
+
"""
|
| 601 |
+
HCXVisionImageProcessor 로 image tensor, original image size (width, height), visual tokens
|
| 602 |
+
|
| 603 |
+
:return pixel_values: List of 4D tensor 로 image tensor
|
| 604 |
+
:return image_sizes: List of Dict 로 image width, height [{"width": image 1 의 width, "height": image 1 의 height}, {"width": image 2 의 width, "height": image 2 의 height}, ...]
|
| 605 |
+
:return vision_query_lengths: List of int 로 각 image 가 LLM 입력으로 전달될때 변환되는 visual token 수
|
| 606 |
+
"""
|
| 607 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 608 |
+
size = size if size is not None else self.size
|
| 609 |
+
size = get_size_dict(size, param_name="size", default_to_square=False)
|
| 610 |
+
anyres = anyres if anyres is not None else self.anyres
|
| 611 |
+
unpad = unpad if unpad is not None else self.unpad
|
| 612 |
+
if video:
|
| 613 |
+
unpad = False
|
| 614 |
+
num_queries_vis_abstractor = (
|
| 615 |
+
num_queries_vis_abstractor
|
| 616 |
+
if num_queries_vis_abstractor is not None
|
| 617 |
+
else self.num_queries_vis_abstractor
|
| 618 |
+
)
|
| 619 |
+
possible_resolutions = (
|
| 620 |
+
possible_resolutions
|
| 621 |
+
if possible_resolutions is not None
|
| 622 |
+
else self.possible_resolutions
|
| 623 |
+
)
|
| 624 |
+
patch_size = patch_size if patch_size is not None else self.patch_size
|
| 625 |
+
pad_to_square = (
|
| 626 |
+
pad_to_square if pad_to_square is not None else self.pad_to_square
|
| 627 |
+
)
|
| 628 |
+
resample = resample if resample is not None else self.resample
|
| 629 |
+
do_center_crop = (
|
| 630 |
+
do_center_crop if do_center_crop is not None else self.do_center_crop
|
| 631 |
+
)
|
| 632 |
+
crop_size = crop_size if crop_size is not None else self.crop_size
|
| 633 |
+
crop_size = get_size_dict(
|
| 634 |
+
crop_size, param_name="crop_size", default_to_square=True
|
| 635 |
+
)
|
| 636 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 637 |
+
rescale_factor = (
|
| 638 |
+
rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 639 |
+
)
|
| 640 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 641 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 642 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 643 |
+
do_convert_rgb = (
|
| 644 |
+
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
if return_dummy_image:
|
| 648 |
+
images = Image.new("RGB", (224, 224), (0, 0, 0))
|
| 649 |
+
|
| 650 |
+
images = make_list_of_images(images)
|
| 651 |
+
|
| 652 |
+
if not valid_images(images):
|
| 653 |
+
raise ValueError(
|
| 654 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 655 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
if do_convert_rgb:
|
| 659 |
+
images = [convert_to_rgb(image) for image in images]
|
| 660 |
+
|
| 661 |
+
images = [to_numpy_array(image) for image in images]
|
| 662 |
+
|
| 663 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 664 |
+
logger.warning_once(
|
| 665 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 666 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
if input_data_format is None:
|
| 670 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 671 |
+
|
| 672 |
+
new_images = []
|
| 673 |
+
image_sizes = [
|
| 674 |
+
get_image_size(image, channel_dim=input_data_format) for image in images
|
| 675 |
+
]
|
| 676 |
+
vision_query_lengths = []
|
| 677 |
+
|
| 678 |
+
assert crop_size["height"] == crop_size["width"]
|
| 679 |
+
|
| 680 |
+
if anyres:
|
| 681 |
+
anyres_global_images = copy.deepcopy(images)
|
| 682 |
+
if pad_to_square:
|
| 683 |
+
background_color = tuple(int(x * 255) for x in self.image_mean)
|
| 684 |
+
anyres_global_images = [
|
| 685 |
+
resize_longside(
|
| 686 |
+
copy.deepcopy(image),
|
| 687 |
+
size["shortest_edge"],
|
| 688 |
+
resample,
|
| 689 |
+
input_data_format,
|
| 690 |
+
)
|
| 691 |
+
for image in anyres_global_images
|
| 692 |
+
]
|
| 693 |
+
anyres_global_images = [
|
| 694 |
+
expand2square(
|
| 695 |
+
image,
|
| 696 |
+
background_color=background_color,
|
| 697 |
+
input_data_format=input_data_format,
|
| 698 |
+
)[0]
|
| 699 |
+
for image in anyres_global_images
|
| 700 |
+
]
|
| 701 |
+
else:
|
| 702 |
+
anyres_global_images = [
|
| 703 |
+
self.resize(
|
| 704 |
+
image=image,
|
| 705 |
+
size={
|
| 706 |
+
"height": size["shortest_edge"],
|
| 707 |
+
"width": size["shortest_edge"],
|
| 708 |
+
},
|
| 709 |
+
resample=resample,
|
| 710 |
+
input_data_format=input_data_format,
|
| 711 |
+
)
|
| 712 |
+
for image in anyres_global_images
|
| 713 |
+
]
|
| 714 |
+
else:
|
| 715 |
+
anyres_global_images = [None for _ in range(len(images))]
|
| 716 |
+
if pad_to_square:
|
| 717 |
+
background_color = tuple(int(x * 255) for x in self.image_mean)
|
| 718 |
+
images = [
|
| 719 |
+
resize_longside(
|
| 720 |
+
image, size["shortest_edge"], resample, input_data_format
|
| 721 |
+
)
|
| 722 |
+
for image in images
|
| 723 |
+
]
|
| 724 |
+
images = [
|
| 725 |
+
expand2square(
|
| 726 |
+
image,
|
| 727 |
+
background_color=background_color,
|
| 728 |
+
input_data_format=input_data_format,
|
| 729 |
+
)[0]
|
| 730 |
+
for image in images
|
| 731 |
+
]
|
| 732 |
+
|
| 733 |
+
for image, anyres_global_image, image_size in zip(
|
| 734 |
+
images, anyres_global_images, image_sizes
|
| 735 |
+
):
|
| 736 |
+
if anyres:
|
| 737 |
+
image_grids = self.get_image_grids(
|
| 738 |
+
image,
|
| 739 |
+
possible_resolutions,
|
| 740 |
+
grid_size=crop_size["height"],
|
| 741 |
+
resample=resample,
|
| 742 |
+
data_format=input_data_format,
|
| 743 |
+
input_data_format=input_data_format,
|
| 744 |
+
)
|
| 745 |
+
if not video:
|
| 746 |
+
image_grids = [anyres_global_image] + image_grids
|
| 747 |
+
else:
|
| 748 |
+
image_grids = [image]
|
| 749 |
+
|
| 750 |
+
pixel_values = self._preprocess(
|
| 751 |
+
image_grids,
|
| 752 |
+
do_resize=do_resize,
|
| 753 |
+
size=size,
|
| 754 |
+
resample=resample,
|
| 755 |
+
do_center_crop=do_center_crop,
|
| 756 |
+
crop_size=crop_size,
|
| 757 |
+
do_rescale=do_rescale,
|
| 758 |
+
rescale_factor=rescale_factor,
|
| 759 |
+
do_normalize=do_normalize,
|
| 760 |
+
image_mean=image_mean,
|
| 761 |
+
image_std=image_std,
|
| 762 |
+
data_format=data_format,
|
| 763 |
+
input_data_format=input_data_format,
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
pixel_values = np.array(pixel_values)
|
| 767 |
+
new_images.append(pixel_values)
|
| 768 |
+
|
| 769 |
+
num_grids = pixel_values.shape[0]
|
| 770 |
+
|
| 771 |
+
vision_query_length = determine_anyres_num_vision_patches(
|
| 772 |
+
num_grids=num_grids,
|
| 773 |
+
image_size=image_size,
|
| 774 |
+
grid_size=crop_size["height"],
|
| 775 |
+
patch_size=patch_size,
|
| 776 |
+
possible_resolutions=possible_resolutions,
|
| 777 |
+
anyres=anyres,
|
| 778 |
+
unpad=unpad,
|
| 779 |
+
num_queries_vis_abstractor=num_queries_vis_abstractor,
|
| 780 |
+
num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
|
| 781 |
+
video=video,
|
| 782 |
+
first_last_frames_slow=first_last_frames_slow,
|
| 783 |
+
is_first_or_last_frames=is_first_or_last_frames,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
vision_query_lengths.append(vision_query_length)
|
| 787 |
+
|
| 788 |
+
if return_dummy_image:
|
| 789 |
+
vision_query_lengths = []
|
| 790 |
+
|
| 791 |
+
data = {
|
| 792 |
+
"pixel_values": [torch.tensor(new_image) for new_image in new_images],
|
| 793 |
+
"image_sizes": [
|
| 794 |
+
{"width": image_size[1], "height": image_size[0]}
|
| 795 |
+
for image_size in image_sizes
|
| 796 |
+
],
|
| 797 |
+
"vision_query_lengths": vision_query_lengths,
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
return BatchFeature(data=data)
|
| 801 |
+
|
| 802 |
+
def save_pretrained(
|
| 803 |
+
self,
|
| 804 |
+
save_directory: Union[str, os.PathLike],
|
| 805 |
+
*args,
|
| 806 |
+
**kwargs,
|
| 807 |
+
):
|
| 808 |
+
self.register_for_auto_class()
|
| 809 |
+
super().save_pretrained(save_directory, *args, **kwargs)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
class HCXVisionV2Processor(Qwen2_5_VLProcessor):
|
| 813 |
+
attributes = ["image_processor", "tokenizer", "video_processor"]
|
| 814 |
+
image_processor_class = "AutoImageProcessor"
|
| 815 |
+
video_processor_class = "AutoVideoProcessor"
|
| 816 |
+
tokenizer_class = (
|
| 817 |
+
"GPT2Tokenizer",
|
| 818 |
+
"GPT2TokenizerFast",
|
| 819 |
+
"PreTrainedTokenizer",
|
| 820 |
+
"PreTrainedTokenizerFast",
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
def __init__(
|
| 824 |
+
self,
|
| 825 |
+
image_processor=None,
|
| 826 |
+
tokenizer=None,
|
| 827 |
+
video_processor=None,
|
| 828 |
+
chat_template=None,
|
| 829 |
+
**kwargs,
|
| 830 |
+
):
|
| 831 |
+
self.tokenizer = tokenizer
|
| 832 |
+
super().__init__(
|
| 833 |
+
image_processor,
|
| 834 |
+
tokenizer,
|
| 835 |
+
video_processor,
|
| 836 |
+
chat_template=self.tokenizer.chat_template,
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
def save_pretrained(
|
| 840 |
+
self,
|
| 841 |
+
save_directory: Union[str, os.PathLike],
|
| 842 |
+
*args,
|
| 843 |
+
**kwargs,
|
| 844 |
+
):
|
| 845 |
+
self.register_for_auto_class()
|
| 846 |
+
super().save_pretrained(save_directory, *args, **kwargs)
|
| 847 |
+
|
| 848 |
+
def __call__(
|
| 849 |
+
self,
|
| 850 |
+
images: ImageInput = None,
|
| 851 |
+
text: Union[
|
| 852 |
+
TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
|
| 853 |
+
] = None,
|
| 854 |
+
videos: VideoInput = None,
|
| 855 |
+
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
|
| 856 |
+
) -> BatchFeature:
|
| 857 |
+
"""
|
| 858 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
| 859 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
| 860 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
| 861 |
+
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
| 862 |
+
|
| 863 |
+
Args:
|
| 864 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
| 865 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 866 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 867 |
+
text (`str`, `list[str]`, `list[list[str]]`):
|
| 868 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 869 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 870 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 871 |
+
videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
| 872 |
+
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
| 873 |
+
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
| 874 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 875 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 876 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 877 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 878 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 879 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 880 |
+
|
| 881 |
+
Returns:
|
| 882 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 883 |
+
|
| 884 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 885 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 886 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 887 |
+
`None`).
|
| 888 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 889 |
+
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
| 890 |
+
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
| 891 |
+
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
|
| 892 |
+
"""
|
| 893 |
+
output_kwargs = self._merge_kwargs(
|
| 894 |
+
Qwen2_5_VLProcessorKwargs,
|
| 895 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 896 |
+
**kwargs,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
image_inputs = videos_inputs = {}
|
| 900 |
+
if images is not None:
|
| 901 |
+
image_inputs = self.image_processor(
|
| 902 |
+
images=images, **output_kwargs["images_kwargs"]
|
| 903 |
+
)
|
| 904 |
+
image_grid_thw = image_inputs["image_grid_thw"]
|
| 905 |
+
|
| 906 |
+
if videos is not None:
|
| 907 |
+
videos_inputs = self.video_processor(
|
| 908 |
+
videos=videos, **output_kwargs["videos_kwargs"]
|
| 909 |
+
)
|
| 910 |
+
video_grid_thw = videos_inputs["video_grid_thw"]
|
| 911 |
+
|
| 912 |
+
if not isinstance(text, list):
|
| 913 |
+
text = [text]
|
| 914 |
+
|
| 915 |
+
text = text.copy()
|
| 916 |
+
|
| 917 |
+
if images is not None:
|
| 918 |
+
merge_length = self.image_processor.merge_size**2
|
| 919 |
+
index = 0
|
| 920 |
+
for i in range(len(text)):
|
| 921 |
+
while self.image_token in text[i]:
|
| 922 |
+
num_image_tokens = image_grid_thw[index].prod() // merge_length
|
| 923 |
+
text[i] = text[i].replace(
|
| 924 |
+
self.image_token, "<|placeholder|>" * num_image_tokens, 1
|
| 925 |
+
)
|
| 926 |
+
text[i] = text[i].replace(
|
| 927 |
+
'{"resolution": [w, h]}',
|
| 928 |
+
'{"resolution": ' + str(list(images[i].size)) + "}",
|
| 929 |
+
)
|
| 930 |
+
index += 1
|
| 931 |
+
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
| 932 |
+
|
| 933 |
+
if videos is not None:
|
| 934 |
+
merge_length = self.video_processor.merge_size**2
|
| 935 |
+
index = 0
|
| 936 |
+
for i in range(len(text)):
|
| 937 |
+
while self.video_token in text[i]:
|
| 938 |
+
num_video_tokens = video_grid_thw[index].prod() // merge_length
|
| 939 |
+
text[i] = text[i].replace(
|
| 940 |
+
self.video_token, "<|placeholder|>" * num_video_tokens, 1
|
| 941 |
+
)
|
| 942 |
+
index += 1
|
| 943 |
+
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
| 944 |
+
|
| 945 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 946 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop(
|
| 947 |
+
"return_mm_token_type_ids", False
|
| 948 |
+
)
|
| 949 |
+
text_inputs = self.tokenizer(
|
| 950 |
+
text, **output_kwargs["text_kwargs"], return_tensors=None
|
| 951 |
+
)
|
| 952 |
+
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
| 953 |
+
|
| 954 |
+
if return_mm_token_type_ids:
|
| 955 |
+
array_ids = np.array(text_inputs["input_ids"])
|
| 956 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 957 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 958 |
+
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 959 |
+
|
| 960 |
+
return BatchFeature(
|
| 961 |
+
data={**text_inputs, **image_inputs, **videos_inputs},
|
| 962 |
+
tensor_type=return_tensors,
|
| 963 |
+
)
|
processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_vlm.HCXVisionV2Processor"
|
| 4 |
+
},
|
| 5 |
+
"processor_class": "HCXVisionV2Processor"
|
| 6 |
+
}
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"eos_token": {
|
| 3 |
+
"content": "<|im_end|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"pad_token": {
|
| 10 |
+
"content": "<|im_end|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"sep_token": {
|
| 17 |
+
"content": "<|endoftext|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<|endoftext|>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
ta_tok.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import inspect
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torchvision.transforms import Resize
|
| 9 |
+
from transformers import AutoConfig, AutoModel, Siglip2VisionConfig, Siglip2VisionModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def models_make(model_spec, args=None, load_sd=False) -> torch.nn.Module:
|
| 13 |
+
if args is not None:
|
| 14 |
+
model_args = copy.deepcopy(model_spec["args"])
|
| 15 |
+
model_args.update(args)
|
| 16 |
+
else:
|
| 17 |
+
model_args = model_spec["args"]
|
| 18 |
+
model_params = inspect.signature(models[model_spec["name"]]).parameters
|
| 19 |
+
if "kwargs" not in model_params:
|
| 20 |
+
model_args = {k: v for k, v in model_args.items() if k in model_params}
|
| 21 |
+
model = models[model_spec["name"]](**model_args)
|
| 22 |
+
if load_sd:
|
| 23 |
+
if (
|
| 24 |
+
("abs_pe" in model_spec["sd"])
|
| 25 |
+
and hasattr(model, "abs_pe")
|
| 26 |
+
and model_spec["sd"]["abs_pe"].shape != model.abs_pe.shape
|
| 27 |
+
):
|
| 28 |
+
del model_spec["sd"]["abs_pe"]
|
| 29 |
+
msg = model.load_state_dict(model_spec["sd"], strict=False)
|
| 30 |
+
print(msg)
|
| 31 |
+
return model
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Bottleneck(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
bottleneck_dim: int,
|
| 38 |
+
input_dim: int,
|
| 39 |
+
output_dim: int,
|
| 40 |
+
token_nums: int,
|
| 41 |
+
regularizer=None,
|
| 42 |
+
**kwargs,
|
| 43 |
+
):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.token_nums = token_nums
|
| 46 |
+
self.input_dim = input_dim
|
| 47 |
+
self.output_dim = output_dim
|
| 48 |
+
if bottleneck_dim > 0:
|
| 49 |
+
self.bottleneck_dim = bottleneck_dim
|
| 50 |
+
else:
|
| 51 |
+
assert (
|
| 52 |
+
self.input_dim == self.output_dim
|
| 53 |
+
), "input_dim and output_dim must be the same when bottleneck_dim is not specified"
|
| 54 |
+
self.bottleneck_dim = self.input_dim
|
| 55 |
+
|
| 56 |
+
self.project_dim = self.bottleneck_dim
|
| 57 |
+
|
| 58 |
+
if self.bottleneck_dim > 0:
|
| 59 |
+
self.in_linear = nn.Linear(self.input_dim, self.project_dim)
|
| 60 |
+
self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
|
| 61 |
+
else:
|
| 62 |
+
self.in_linear = self.out_linear = lambda x: x
|
| 63 |
+
|
| 64 |
+
regularizer["args"]["dim"] = self.bottleneck_dim
|
| 65 |
+
regularizer["args"]["token_nums"] = self.token_nums
|
| 66 |
+
self.regularizer = models_make(regularizer)
|
| 67 |
+
|
| 68 |
+
def project_in(self, x):
|
| 69 |
+
assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
|
| 70 |
+
z = self.in_linear(x)
|
| 71 |
+
return z
|
| 72 |
+
|
| 73 |
+
def project_out(self, z_cat):
|
| 74 |
+
z = self.out_linear(z_cat)
|
| 75 |
+
return z
|
| 76 |
+
|
| 77 |
+
def decode(self, bottleneck_rep):
|
| 78 |
+
regularized_z = self.regularizer.decode(bottleneck_rep)
|
| 79 |
+
return self.project_out(regularized_z)
|
| 80 |
+
|
| 81 |
+
def forward(self, x):
|
| 82 |
+
z = self.project_in(x)
|
| 83 |
+
projected_z = z
|
| 84 |
+
regularized_output = self.regularizer(z)
|
| 85 |
+
x_hat = self.project_out(regularized_output["regularized_z"])
|
| 86 |
+
bottleneck_rep = regularized_output.pop("bottleneck_rep")
|
| 87 |
+
return {
|
| 88 |
+
"output": x_hat,
|
| 89 |
+
"bottleneck_rep": bottleneck_rep,
|
| 90 |
+
"projected_z": projected_z,
|
| 91 |
+
**regularized_output,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class SimVectorQuantizer(nn.Module):
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
dim,
|
| 99 |
+
codebook_size,
|
| 100 |
+
l2_normalized=False,
|
| 101 |
+
same_index_shape=True,
|
| 102 |
+
stochastic=False,
|
| 103 |
+
stochastic_temperature=1.0,
|
| 104 |
+
**kwargs,
|
| 105 |
+
):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.codebook_size = codebook_size
|
| 108 |
+
self.dim = dim
|
| 109 |
+
assert isinstance(l2_normalized, bool)
|
| 110 |
+
self.l2_normalized = l2_normalized
|
| 111 |
+
self.stochastic = stochastic
|
| 112 |
+
self.eval_deterministic = False
|
| 113 |
+
self.default_stochastic_temperature = stochastic_temperature
|
| 114 |
+
|
| 115 |
+
if self.stochastic:
|
| 116 |
+
if stochastic_temperature > 0:
|
| 117 |
+
self.stochastic_temperature_inv = 1 / stochastic_temperature
|
| 118 |
+
else:
|
| 119 |
+
self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))
|
| 120 |
+
|
| 121 |
+
self.embedding = nn.Embedding(self.codebook_size, self.dim)
|
| 122 |
+
self.embedding_proj = nn.Linear(self.dim, self.dim)
|
| 123 |
+
|
| 124 |
+
self.same_index_shape = same_index_shape
|
| 125 |
+
|
| 126 |
+
def set_eval_deterministic(self, deterministic=True):
|
| 127 |
+
self.eval_deterministic = deterministic
|
| 128 |
+
|
| 129 |
+
def set_stochastic_temperature(self, temperature):
|
| 130 |
+
self.stochastic_temperature_inv = 1 / temperature
|
| 131 |
+
|
| 132 |
+
@torch.autocast(device_type="cuda", enabled=False)
|
| 133 |
+
def get_emb(self):
|
| 134 |
+
emb = self.embedding_proj(self.embedding.weight)
|
| 135 |
+
if self.l2_normalized:
|
| 136 |
+
emb = F.normalize(emb, p=2, dim=-1)
|
| 137 |
+
return emb
|
| 138 |
+
|
| 139 |
+
@torch.autocast(device_type="cuda", enabled=False)
|
| 140 |
+
def forward(self, z):
|
| 141 |
+
emb = self.get_emb()
|
| 142 |
+
z = z.to(emb)
|
| 143 |
+
assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
|
| 144 |
+
if self.l2_normalized:
|
| 145 |
+
z = F.normalize(z, p=2, dim=-1)
|
| 146 |
+
|
| 147 |
+
z_flattened = rearrange(z, "b n d -> (b n) d")
|
| 148 |
+
|
| 149 |
+
if self.stochastic:
|
| 150 |
+
assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
|
| 151 |
+
cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
|
| 152 |
+
probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
|
| 153 |
+
if self.eval_deterministic and not self.training:
|
| 154 |
+
q_indices = torch.argmax(probs, dim=-1)
|
| 155 |
+
else:
|
| 156 |
+
q_indices = torch.multinomial(probs, 1).squeeze(-1)
|
| 157 |
+
else:
|
| 158 |
+
d = (
|
| 159 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
| 160 |
+
+ torch.sum(emb**2, dim=1)
|
| 161 |
+
- 2
|
| 162 |
+
* torch.einsum("bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n"))
|
| 163 |
+
)
|
| 164 |
+
q_indices = torch.argmin(d, dim=1)
|
| 165 |
+
|
| 166 |
+
quantized = F.embedding(
|
| 167 |
+
q_indices,
|
| 168 |
+
emb,
|
| 169 |
+
self.embedding.padding_idx,
|
| 170 |
+
self.embedding.max_norm,
|
| 171 |
+
self.embedding.norm_type,
|
| 172 |
+
self.embedding.scale_grad_by_freq,
|
| 173 |
+
self.embedding.sparse,
|
| 174 |
+
).view(z.shape)
|
| 175 |
+
|
| 176 |
+
quantized = z + (quantized - z).detach()
|
| 177 |
+
|
| 178 |
+
if self.same_index_shape:
|
| 179 |
+
q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])
|
| 180 |
+
|
| 181 |
+
return_dict = {
|
| 182 |
+
"unregularized_z": z,
|
| 183 |
+
"emb": emb,
|
| 184 |
+
"regularized_z": quantized,
|
| 185 |
+
"bottleneck_rep": q_indices,
|
| 186 |
+
}
|
| 187 |
+
return return_dict
|
| 188 |
+
|
| 189 |
+
def get_codebook_entry(self, indices, shape=None):
|
| 190 |
+
indices_shape = indices.shape
|
| 191 |
+
indices_flatten = rearrange(indices, "... -> (...)")
|
| 192 |
+
|
| 193 |
+
emb = self.get_emb()
|
| 194 |
+
z_q = F.embedding(indices_flatten, emb)
|
| 195 |
+
if self.l2_normalized:
|
| 196 |
+
z_q = F.normalize(z_q, p=2, dim=-1)
|
| 197 |
+
|
| 198 |
+
if shape is not None:
|
| 199 |
+
z_q = z_q.reshape(shape)
|
| 200 |
+
else:
|
| 201 |
+
z_q = z_q.reshape([*indices_shape, self.dim])
|
| 202 |
+
return z_q
|
| 203 |
+
|
| 204 |
+
def decode(self, indices):
|
| 205 |
+
return self.get_codebook_entry(indices)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
models = {"simvq": SimVectorQuantizer, "bottleneck": Bottleneck}
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class ScalingLayer(nn.Module):
|
| 212 |
+
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
| 213 |
+
super().__init__()
|
| 214 |
+
self.register_buffer("shift", torch.Tensor(mean)[None, :, None, None])
|
| 215 |
+
self.register_buffer("scale", torch.Tensor(std)[None, :, None, None])
|
| 216 |
+
|
| 217 |
+
def forward(self, inp):
|
| 218 |
+
return (inp - self.shift) / self.scale
|
| 219 |
+
|
| 220 |
+
def inv(self, inp):
|
| 221 |
+
return inp * self.scale + self.shift
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class TextAlignedTokenizer(nn.Module):
|
| 225 |
+
def __init__(
|
| 226 |
+
self,
|
| 227 |
+
bottleneck,
|
| 228 |
+
bottleneck_token_num=256,
|
| 229 |
+
input_size=384,
|
| 230 |
+
teacher="google/siglip2-so400m-patch14-384",
|
| 231 |
+
input_type="quant",
|
| 232 |
+
pool_scale=1,
|
| 233 |
+
decoder_depth=3,
|
| 234 |
+
select_layer_id=-2,
|
| 235 |
+
*args,
|
| 236 |
+
**kwargs,
|
| 237 |
+
):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.input_size = input_size
|
| 240 |
+
self.bottleneck_token_num = bottleneck_token_num
|
| 241 |
+
self.teacher = teacher
|
| 242 |
+
self.input_type = input_type
|
| 243 |
+
self.pool_scale = pool_scale
|
| 244 |
+
self.decoder_depth = decoder_depth
|
| 245 |
+
self.select_layer_id = select_layer_id
|
| 246 |
+
|
| 247 |
+
self.bottleneck_dim = bottleneck["args"]["bottleneck_dim"]
|
| 248 |
+
|
| 249 |
+
self.encoder_config = AutoConfig.from_pretrained(teacher)
|
| 250 |
+
self.encoder = AutoModel.from_config(self.encoder_config).vision_model
|
| 251 |
+
|
| 252 |
+
self.encoder_hidden_dim = self.encoder.config.hidden_size
|
| 253 |
+
|
| 254 |
+
self.decoder_config = Siglip2VisionConfig()
|
| 255 |
+
self.decoder_config.update(
|
| 256 |
+
{
|
| 257 |
+
"patch_size": 1,
|
| 258 |
+
"num_hidden_layers": self.decoder_depth,
|
| 259 |
+
"num_channels": self.bottleneck_dim,
|
| 260 |
+
"hidden_size": self.encoder_hidden_dim,
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
self.decoder = Siglip2VisionModel(self.decoder_config)
|
| 264 |
+
|
| 265 |
+
self.encode_task_layer = nn.Sequential(
|
| 266 |
+
nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim), nn.Tanh()
|
| 267 |
+
)
|
| 268 |
+
self.decode_task_layer = nn.Sequential(
|
| 269 |
+
nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
|
| 270 |
+
nn.Tanh(),
|
| 271 |
+
nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
bottleneck_args = {
|
| 275 |
+
"token_nums": self.bottleneck_token_num,
|
| 276 |
+
"input_dim": self.encoder_hidden_dim,
|
| 277 |
+
"output_dim": self.bottleneck_dim,
|
| 278 |
+
}
|
| 279 |
+
self.bottleneck = models_make(bottleneck, args=bottleneck_args)
|
| 280 |
+
|
| 281 |
+
self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 282 |
+
self.image_resize = Resize((self.input_size, self.input_size))
|
| 283 |
+
|
| 284 |
+
def set_vq_eval_deterministic(self, deterministic=True):
|
| 285 |
+
self.bottleneck.regularizer.set_eval_deterministic(deterministic)
|
| 286 |
+
|
| 287 |
+
@property
|
| 288 |
+
def device(self):
|
| 289 |
+
return next(self.parameters()).device
|
| 290 |
+
|
| 291 |
+
@property
|
| 292 |
+
def dtype(self):
|
| 293 |
+
return next(self.parameters()).dtype
|
| 294 |
+
|
| 295 |
+
@classmethod
|
| 296 |
+
def from_checkpoint(cls, ckpt, load_teacher=True, **kwargs):
|
| 297 |
+
ckpt = torch.load(ckpt, map_location="cpu", weights_only=False)
|
| 298 |
+
ckpt_kwargs = ckpt["model"]["args"]
|
| 299 |
+
print(ckpt_kwargs)
|
| 300 |
+
model = cls(**kwargs, **ckpt_kwargs)
|
| 301 |
+
sd = ckpt["model"]["sd"]
|
| 302 |
+
if not load_teacher:
|
| 303 |
+
sd = {k: v for k, v in sd.items() if not k.startswith("teacher")}
|
| 304 |
+
model.load_state_dict(sd, strict=True)
|
| 305 |
+
return model
|
| 306 |
+
|
| 307 |
+
def encode(self, x, **kwargs):
|
| 308 |
+
if x.ndim == 5:
|
| 309 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 310 |
+
x = self.scale_layer(x)
|
| 311 |
+
if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
|
| 312 |
+
x = self.image_resize(x)
|
| 313 |
+
vq_feats = self.encoder(x, output_hidden_states=True).hidden_states[
|
| 314 |
+
self.select_layer_id
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
pool_scale = self.pool_scale
|
| 318 |
+
pool_scale = kwargs.get("pool_scale", pool_scale)
|
| 319 |
+
if pool_scale != 1:
|
| 320 |
+
vq_feats = self.avg_pool(vq_feats, pool_scale)
|
| 321 |
+
vq_feats = self.encode_task_layer(vq_feats.to(x))
|
| 322 |
+
|
| 323 |
+
bottleneck_out = self.bottleneck(vq_feats)
|
| 324 |
+
z = bottleneck_out.pop("output")
|
| 325 |
+
|
| 326 |
+
return {
|
| 327 |
+
"encoded": z,
|
| 328 |
+
"pool_scale": pool_scale,
|
| 329 |
+
"vq_feats": vq_feats,
|
| 330 |
+
**bottleneck_out,
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
def avg_pool(self, z, pool_scale=1):
|
| 334 |
+
if z.ndim == 3:
|
| 335 |
+
b, n, c = z.shape
|
| 336 |
+
p = int(n**0.5)
|
| 337 |
+
z = rearrange(z, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
|
| 338 |
+
else:
|
| 339 |
+
b, c, p, _ = z.shape
|
| 340 |
+
p_s = int(p // pool_scale)
|
| 341 |
+
z = F.avg_pool2d(
|
| 342 |
+
z, kernel_size=(pool_scale, pool_scale), stride=(pool_scale, pool_scale)
|
| 343 |
+
).contiguous()
|
| 344 |
+
z = rearrange(z, "b c p1 p2 -> b (p1 p2) c")
|
| 345 |
+
return z
|
| 346 |
+
|
| 347 |
+
def decode(self, z):
|
| 348 |
+
if z.ndim == 4:
|
| 349 |
+
z = rearrange(z, "b c p1 p2 -> b (p1 p2) c")
|
| 350 |
+
attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
|
| 351 |
+
p = int(z.shape[1] ** 0.5)
|
| 352 |
+
spatial_shape = torch.tensor([[p, p]] * z.shape[0], device=self.device)
|
| 353 |
+
z = self.decoder(
|
| 354 |
+
z, attention_mask, spatial_shape, output_hidden_states=True
|
| 355 |
+
).last_hidden_state
|
| 356 |
+
z = self.decode_task_layer(z)
|
| 357 |
+
return z
|
| 358 |
+
|
| 359 |
+
def decode_from_bottleneck(self, bottleneck_rep):
|
| 360 |
+
z = self.bottleneck.decode(bottleneck_rep)
|
| 361 |
+
p = int(z.shape[1] ** 0.5)
|
| 362 |
+
z = rearrange(z, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
|
| 363 |
+
return self.decode(z)
|
| 364 |
+
|
| 365 |
+
def forward(self, data, **kwargs):
|
| 366 |
+
encode_output = self.encode(data, **kwargs)
|
| 367 |
+
vq_feats = encode_output["encoded"]
|
| 368 |
+
p = int(vq_feats.shape[1] ** 0.5)
|
| 369 |
+
vq_feats = rearrange(vq_feats, "b (h w) c -> b c h w", h=p, w=p)
|
| 370 |
+
pred_feats = self.decode(vq_feats)
|
| 371 |
+
|
| 372 |
+
if self.input_type == "quant":
|
| 373 |
+
z = encode_output["regularized_z"]
|
| 374 |
+
elif self.input_type == "indices":
|
| 375 |
+
z = encode_output["bottleneck_rep"]
|
| 376 |
+
elif self.input_type == "rec":
|
| 377 |
+
z = pred_feats
|
| 378 |
+
encode_output["encoded"] = z
|
| 379 |
+
return encode_output
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:666f303c324b9b2e2e8f13950cd44a18896a6fc1a70aae70583a77663d0ebe31
|
| 3 |
+
size 23621510
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d7bac0a38c3af8f44d4b3b23d536111c1493fea74d3e7e2a71d804f63dada55
|
| 3 |
+
size 13220225
|
video_preprocessor_config.json
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_valid_kwargs_names": [
|
| 3 |
+
"do_convert_rgb",
|
| 4 |
+
"do_resize",
|
| 5 |
+
"size",
|
| 6 |
+
"size_divisor",
|
| 7 |
+
"default_to_square",
|
| 8 |
+
"resample",
|
| 9 |
+
"do_rescale",
|
| 10 |
+
"rescale_factor",
|
| 11 |
+
"do_normalize",
|
| 12 |
+
"image_mean",
|
| 13 |
+
"image_std",
|
| 14 |
+
"do_pad",
|
| 15 |
+
"do_center_crop",
|
| 16 |
+
"crop_size",
|
| 17 |
+
"data_format",
|
| 18 |
+
"input_data_format",
|
| 19 |
+
"device",
|
| 20 |
+
"min_pixels",
|
| 21 |
+
"max_pixels",
|
| 22 |
+
"patch_size",
|
| 23 |
+
"temporal_patch_size",
|
| 24 |
+
"merge_size"
|
| 25 |
+
],
|
| 26 |
+
"auto_map": {
|
| 27 |
+
"AutoProcessor": "processing_vlm.HCXVisionV2Processor"
|
| 28 |
+
},
|
| 29 |
+
"crop_size": null,
|
| 30 |
+
"data_format": "channels_first",
|
| 31 |
+
"default_to_square": true,
|
| 32 |
+
"device": null,
|
| 33 |
+
"do_center_crop": null,
|
| 34 |
+
"do_convert_rgb": true,
|
| 35 |
+
"do_normalize": true,
|
| 36 |
+
"do_pad": null,
|
| 37 |
+
"do_rescale": true,
|
| 38 |
+
"do_resize": true,
|
| 39 |
+
"image_mean": [
|
| 40 |
+
0.48145466,
|
| 41 |
+
0.4578275,
|
| 42 |
+
0.40821073
|
| 43 |
+
],
|
| 44 |
+
"image_processor_type": "Qwen2VLImageProcessor",
|
| 45 |
+
"image_std": [
|
| 46 |
+
0.26862954,
|
| 47 |
+
0.26130258,
|
| 48 |
+
0.27577711
|
| 49 |
+
],
|
| 50 |
+
"input_data_format": null,
|
| 51 |
+
"max_pixels": 12845056,
|
| 52 |
+
"merge_size": 2,
|
| 53 |
+
"min_pixels": 3136,
|
| 54 |
+
"model_valid_processing_keys": [
|
| 55 |
+
"do_convert_rgb",
|
| 56 |
+
"do_resize",
|
| 57 |
+
"size",
|
| 58 |
+
"size_divisor",
|
| 59 |
+
"default_to_square",
|
| 60 |
+
"resample",
|
| 61 |
+
"do_rescale",
|
| 62 |
+
"rescale_factor",
|
| 63 |
+
"do_normalize",
|
| 64 |
+
"image_mean",
|
| 65 |
+
"image_std",
|
| 66 |
+
"do_pad",
|
| 67 |
+
"do_center_crop",
|
| 68 |
+
"crop_size",
|
| 69 |
+
"data_format",
|
| 70 |
+
"input_data_format",
|
| 71 |
+
"device",
|
| 72 |
+
"min_pixels",
|
| 73 |
+
"max_pixels",
|
| 74 |
+
"patch_size",
|
| 75 |
+
"temporal_patch_size",
|
| 76 |
+
"merge_size"
|
| 77 |
+
],
|
| 78 |
+
"patch_size": 14,
|
| 79 |
+
"processor_class": "HCXVisionV2Processor",
|
| 80 |
+
"resample": 3,
|
| 81 |
+
"rescale_factor": 0.00392156862745098,
|
| 82 |
+
"size": {
|
| 83 |
+
"longest_edge": 12845056,
|
| 84 |
+
"shortest_edge": 3136
|
| 85 |
+
},
|
| 86 |
+
"size_divisor": null,
|
| 87 |
+
"temporal_patch_size": 2,
|
| 88 |
+
"video_processor_type": "Qwen2VLVideoProcessor"
|
| 89 |
+
}
|