Upload folder using huggingface_hub
- .gitattributes +21 -35
- LICENSE +201 -0
- README.md +110 -3
- added_tokens.json +1 -0
- config.json +36 -0
- configuration_ernie4_5_moe.py +194 -0
- generation_config.json +11 -0
- model-00001-of-00011.safetensors +3 -0
- model-00002-of-00011.safetensors +3 -0
- model-00003-of-00011.safetensors +3 -0
- model-00004-of-00011.safetensors +3 -0
- model-00005-of-00011.safetensors +3 -0
- model-00006-of-00011.safetensors +3 -0
- model-00007-of-00011.safetensors +3 -0
- model-00008-of-00011.safetensors +3 -0
- model-00009-of-00011.safetensors +3 -0
- model-00010-of-00011.safetensors +3 -0
- model-00011-of-00011.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_ernie4_5_moe.py +1412 -0
- special_tokens_map.json +1 -0
- tokenization_ernie4_5.py +374 -0
- tokenizer.model +3 -0
- tokenizer_config.json +22 -0
.gitattributes
CHANGED
@@ -1,35 +1,21 @@
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+model-00001-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00003-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00006-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00008-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00009-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00002-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00004-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00005-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00007-of-00009.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00001-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00002-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00003-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00004-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00005-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00006-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00007-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00008-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00009-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00010-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00011-of-00011.safetensors filter=lfs diff=lfs merge=lfs -text
+tokenizer.model filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright (c) 2025 Baidu, Inc. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,3 +1,110 @@
# ERNIE-4.5-21B-A3B-Base

## ERNIE 4.5 Highlights

The advanced capabilities of the ERNIE 4.5 models, particularly the MoE-based A47B and A3B series, are underpinned by several key technical innovations:

- **Multimodal MoE Pretraining:** Our models are jointly trained on both textual and visual modalities to better capture the nuances of multimodal information and improve performance on tasks involving text generation, image understanding, and cross-modal reasoning. To achieve this without one modality hindering the learning of the other, we designed a heterogeneous MoE structure, incorporated three-dimensional rotary embeddings, and employed a router orthogonal loss and a multimodal token-balanced loss. These architectural choices ensure that both modalities are effectively represented, allowing for mutual reinforcement during training.
- **Scaling-Efficient Architecture and Infrastructure:** To train the large multimodal MoE models efficiently, we introduce a novel heterogeneous hybrid parallelism and a multi-level load-balancing strategy. By using on-device expert parallelism, memory-efficient pipeline scheduling, and FP8 mixed precision, we achieve excellent pre-training performance. For inference, we propose a quantization method with collaborative parallelism among multiple experts to achieve lossless quantization. Built on PaddlePaddle, ERNIE 4.5 delivers high-performance inference across a wide range of hardware platforms.
- **Modality-Specific Post-training:** To meet the diverse requirements of real-world applications, we fine-tuned variants of the pretrained model for specific modalities. Our LLMs are optimized for general-purpose language understanding and generation. The VLMs focus on visual-language understanding and support both thinking and non-thinking modes. Each model was post-trained with a combination of Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO), or a modified reinforcement learning method named Unified Preference Optimization (UPO), using targeted datasets aligned with its intended usage scenario.

To ensure the stability of multimodal joint training, we adopt a staged training strategy. In the first and second stages, we train only the text-related parameters, enabling the model to develop strong fundamental language understanding as well as long-text processing capabilities. The final multimodal stage extends these capabilities to images and videos by introducing additional parameters, including a ViT for image feature extraction, an adapter for feature transformation, and visual experts for multimodal understanding. At this stage, the text and visual modalities mutually enhance each other. After pretraining on trillions of tokens, we extracted the text-related parameters to obtain ERNIE-4.5-21B-A3B-Base.

## Model Overview

ERNIE-4.5-21B-A3B-Base is a text MoE base model with 21B total parameters and 3B activated parameters per token. The model configuration details are as follows:

| Key                                | Value       |
| ---------------------------------- | ----------- |
| Modality                           | Text        |
| Training Stage                     | Pretraining |
| Params (Total / Activated)         | 21B / 3B    |
| Layers                             | 28          |
| Heads (Q/KV)                       | 20 / 4      |
| Text Experts (Total / Activated)   | 64 / 6      |
| Vision Experts (Total / Activated) | 64 / 6      |
| Shared Experts                     | 2           |
| Context Length                     | 131072      |
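
As a rough cross-check of the 21B / 3B figures, the table can be combined with the values in `config.json` further below. This is only a back-of-the-envelope sketch: the head dimension of 128, gated three-matrix feed-forward blocks, shared experts sized like routed experts, a single dense first layer, and tied embeddings are assumptions not stated in this card, and router/normalization parameters are ignored.

```python
# Back-of-the-envelope estimate of total vs. activated parameters from config.json.
hidden, vocab, layers = 2560, 103424, 28
heads, kv_heads, head_dim = 20, 4, 128          # head_dim = hidden / heads (assumed)
dense_ffn_dim, expert_ffn_dim = 12288, 1536
experts, top_k, shared = 64, 6, 2

attn = hidden * head_dim * (2 * heads + 2 * kv_heads)   # q, k, v, o projections
per_expert = 3 * hidden * expert_ffn_dim                # gate/up/down matrices
dense_ffn = 3 * hidden * dense_ffn_dim                  # layer 0 dense FFN (assumed)
embed = vocab * hidden                                  # tied in/out embeddings

common = embed + layers * attn + dense_ffn
total = common + (layers - 1) * (experts + shared) * per_expert
active = common + (layers - 1) * (top_k + shared) * per_expert
print(f"total ~ {total / 1e9:.1f}B, activated ~ {active / 1e9:.1f}B")  # ~21.8B, ~3.3B
```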

## Quickstart

### Model Finetuning with ERNIEKit

[ERNIEKit](https://github.com/PaddlePaddle/ERNIE) is a training toolkit based on PaddlePaddle, specifically designed for the ERNIE series of open-source large models. It provides comprehensive support for scenarios such as instruction fine-tuning (SFT, LoRA) and alignment training (DPO), ensuring optimal performance.

Usage Examples:

```bash
# SFT
erniekit train --stage SFT --model_name_or_path /baidu/ERNIE-4.5-21B-A3B-Base --train_dataset_path your_dataset_path
# DPO
erniekit train --stage DPO --model_name_or_path /baidu/ERNIE-4.5-21B-A3B-Base --train_dataset_path your_dataset_path
```

For more detailed examples, including SFT with LoRA, multi-GPU configurations, and advanced scripts, please refer to the examples folder within the [ERNIEKit](https://github.com/PaddlePaddle/ERNIE) repository.

### FastDeploy Inference

Service deployment can be completed quickly with FastDeploy using the command below. For more detailed usage instructions, please refer to the [FastDeploy repository](https://github.com/PaddlePaddle/FastDeploy).

**Note**: Single-card deployment requires at least 80GB of GPU memory.

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model BAIDU/ERNIE-4.5-21B-A3B-Base-Paddle \
    --port 8180 \
    --metrics-port 8181 \
    --engine-worker-queue-port 8182 \
    --max-model-len 32768 \
    --max-num-seqs 32
```

`--max-model-len` sets the maximum supported number of tokens, and `--max-num-seqs` sets the maximum number of concurrently processed sequences.
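
Once the server is up, it can be queried over HTTP. The entrypoint name suggests an OpenAI-compatible API; the sketch below assumes the standard `/v1/completions` route on the port configured above, so adjust to your deployment.

```python
import requests

# Minimal client sketch; assumes the FastDeploy server started above exposes
# the standard OpenAI-compatible completions route on localhost:8180.
resp = requests.post(
    "http://localhost:8180/v1/completions",
    json={
        "model": "BAIDU/ERNIE-4.5-21B-A3B-Base-Paddle",
        "prompt": "Large language model is",
        "max_tokens": 128,
    },
    timeout=600,
)
print(resp.json()["choices"][0]["text"])
```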
### Using `transformers` library

The following code snippet illustrates how to use the model to generate content from a given input.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-21B-A3B-Base-PT"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

prompt = "Large language model is"
model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(model.device)

generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=1024
)
result = tokenizer.decode(generated_ids[0].tolist(), skip_special_tokens=True)
print("result:", result)
```
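
A 21B-parameter checkpoint is large for a single device, so as a sketch (not part of the original card) you can load it in bfloat16, matching the `torch_dtype` in `config.json`, and let `transformers` shard it across available GPUs, assuming `accelerate` is installed:

```python
import torch
from transformers import AutoModelForCausalLM

# Load in bfloat16 and shard automatically across available devices.
model = AutoModelForCausalLM.from_pretrained(
    "baidu/ERNIE-4.5-21B-A3B-Base-PT",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
```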
### vLLM inference

vLLM support is still being adapted upstream; for now, please use our fork repository [vllm](https://github.com/CSWYF3634076/vllm/tree/ernie).

```bash
vllm serve baidu/ERNIE-4.5-21B-A3B-Base-PT --trust-remote-code
```

## License

The ERNIE 4.5 models are provided under the Apache License 2.0. This license permits commercial use, subject to its terms and conditions. Copyright (c) 2025 Baidu, Inc. All Rights Reserved.

## Citation

If you find ERNIE 4.5 useful or wish to use it in your projects, please kindly cite our technical report:

```bibtex
@misc{ernie2025technicalreport,
      title={ERNIE 4.5 Technical Report},
      author={Baidu ERNIE Team},
      year={2025},
      eprint={},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={}
}
```
added_tokens.json
ADDED
@@ -0,0 +1 @@
{"<|IMAGE_PLACEHOLDER|>": 100295, "<|AUDIO_PLACEHOLDER|>": 100296, "<|LOC_0|>": 100297, "<|LOC_1|>": 100298, "<|LOC_2|>": 100299, "<|LOC_3|>": 100300, "<|LOC_4|>": 100301, "<|LOC_5|>": 100302, "<|LOC_6|>": 100303, "<|LOC_7|>": 100304, "<|LOC_8|>": 100305, "<|LOC_9|>": 100306, "<|LOC_10|>": 100307, "<|LOC_11|>": 100308, "<|LOC_12|>": 100309, "<|LOC_13|>": 100310, "<|LOC_14|>": 100311, "<|LOC_15|>": 100312, "<|LOC_16|>": 100313, "<|LOC_17|>": 100314, "<|LOC_18|>": 100315, "<|LOC_19|>": 100316, "<|LOC_20|>": 100317, "<|LOC_21|>": 100318, "<|LOC_22|>": 100319, "<|LOC_23|>": 100320, "<|LOC_24|>": 100321, "<|LOC_25|>": 100322, "<|LOC_26|>": 100323, "<|LOC_27|>": 100324, "<|LOC_28|>": 100325, "<|LOC_29|>": 100326, "<|LOC_30|>": 100327, "<|LOC_31|>": 100328, "<|LOC_32|>": 100329, "<|LOC_33|>": 100330, "<|LOC_34|>": 100331, "<|LOC_35|>": 100332, "<|LOC_36|>": 100333, "<|LOC_37|>": 100334, "<|LOC_38|>": 100335, "<|LOC_39|>": 100336, "<|LOC_40|>": 100337, "<|LOC_41|>": 100338, "<|LOC_42|>": 100339, "<|LOC_43|>": 100340, "<|LOC_44|>": 100341, "<|LOC_45|>": 100342, "<|LOC_46|>": 100343, "<|LOC_47|>": 100344, "<|LOC_48|>": 100345, "<|LOC_49|>": 100346, "<|LOC_50|>": 100347, "<|LOC_51|>": 100348, "<|LOC_52|>": 100349, "<|LOC_53|>": 100350, "<|LOC_54|>": 100351, "<|LOC_55|>": 100352, "<|LOC_56|>": 100353, "<|LOC_57|>": 100354, "<|LOC_58|>": 100355, "<|LOC_59|>": 100356, "<|LOC_60|>": 100357, "<|LOC_61|>": 100358, "<|LOC_62|>": 100359, "<|LOC_63|>": 100360, "<|LOC_64|>": 100361, "<|LOC_65|>": 100362, "<|LOC_66|>": 100363, "<|LOC_67|>": 100364, "<|LOC_68|>": 100365, "<|LOC_69|>": 100366, "<|LOC_70|>": 100367, "<|LOC_71|>": 100368, "<|LOC_72|>": 100369, "<|LOC_73|>": 100370, "<|LOC_74|>": 100371, "<|LOC_75|>": 100372, "<|LOC_76|>": 100373, "<|LOC_77|>": 100374, "<|LOC_78|>": 100375, "<|LOC_79|>": 100376, "<|LOC_80|>": 100377, "<|LOC_81|>": 100378, "<|LOC_82|>": 100379, "<|LOC_83|>": 100380, "<|LOC_84|>": 100381, "<|LOC_85|>": 100382, "<|LOC_86|>": 100383, "<|LOC_87|>": 100384, "<|LOC_88|>": 100385, "<|LOC_89|>": 100386, "<|LOC_90|>": 100387, "<|LOC_91|>": 100388, "<|LOC_92|>": 100389, "<|LOC_93|>": 100390, "<|LOC_94|>": 100391, "<|LOC_95|>": 100392, "<|LOC_96|>": 100393, "<|LOC_97|>": 100394, "<|LOC_98|>": 100395, "<|LOC_99|>": 100396, "<|LOC_100|>": 100397, "<|LOC_101|>": 100398, "<|LOC_102|>": 100399, "<|LOC_103|>": 100400, "<|LOC_104|>": 100401, "<|LOC_105|>": 100402, "<|LOC_106|>": 100403, "<|LOC_107|>": 100404, "<|LOC_108|>": 100405, "<|LOC_109|>": 100406, "<|LOC_110|>": 100407, "<|LOC_111|>": 100408, "<|LOC_112|>": 100409, "<|LOC_113|>": 100410, "<|LOC_114|>": 100411, "<|LOC_115|>": 100412, "<|LOC_116|>": 100413, "<|LOC_117|>": 100414, "<|LOC_118|>": 100415, "<|LOC_119|>": 100416, "<|LOC_120|>": 100417, "<|LOC_121|>": 100418, "<|LOC_122|>": 100419, "<|LOC_123|>": 100420, "<|LOC_124|>": 100421, "<|LOC_125|>": 100422, "<|LOC_126|>": 100423, "<|LOC_127|>": 100424, "<|LOC_128|>": 100425, "<|LOC_129|>": 100426, "<|LOC_130|>": 100427, "<|LOC_131|>": 100428, "<|LOC_132|>": 100429, "<|LOC_133|>": 100430, "<|LOC_134|>": 100431, "<|LOC_135|>": 100432, "<|LOC_136|>": 100433, "<|LOC_137|>": 100434, "<|LOC_138|>": 100435, "<|LOC_139|>": 100436, "<|LOC_140|>": 100437, "<|LOC_141|>": 100438, "<|LOC_142|>": 100439, "<|LOC_143|>": 100440, "<|LOC_144|>": 100441, "<|LOC_145|>": 100442, "<|LOC_146|>": 100443, "<|LOC_147|>": 100444, "<|LOC_148|>": 100445, "<|LOC_149|>": 100446, "<|LOC_150|>": 100447, "<|LOC_151|>": 100448, "<|LOC_152|>": 100449, "<|LOC_153|>": 100450, "<|LOC_154|>": 100451, "<|LOC_155|>": 100452, 
"<|LOC_156|>": 100453, "<|LOC_157|>": 100454, "<|LOC_158|>": 100455, "<|LOC_159|>": 100456, "<|LOC_160|>": 100457, "<|LOC_161|>": 100458, "<|LOC_162|>": 100459, "<|LOC_163|>": 100460, "<|LOC_164|>": 100461, "<|LOC_165|>": 100462, "<|LOC_166|>": 100463, "<|LOC_167|>": 100464, "<|LOC_168|>": 100465, "<|LOC_169|>": 100466, "<|LOC_170|>": 100467, "<|LOC_171|>": 100468, "<|LOC_172|>": 100469, "<|LOC_173|>": 100470, "<|LOC_174|>": 100471, "<|LOC_175|>": 100472, "<|LOC_176|>": 100473, "<|LOC_177|>": 100474, "<|LOC_178|>": 100475, "<|LOC_179|>": 100476, "<|LOC_180|>": 100477, "<|LOC_181|>": 100478, "<|LOC_182|>": 100479, "<|LOC_183|>": 100480, "<|LOC_184|>": 100481, "<|LOC_185|>": 100482, "<|LOC_186|>": 100483, "<|LOC_187|>": 100484, "<|LOC_188|>": 100485, "<|LOC_189|>": 100486, "<|LOC_190|>": 100487, "<|LOC_191|>": 100488, "<|LOC_192|>": 100489, "<|LOC_193|>": 100490, "<|LOC_194|>": 100491, "<|LOC_195|>": 100492, "<|LOC_196|>": 100493, "<|LOC_197|>": 100494, "<|LOC_198|>": 100495, "<|LOC_199|>": 100496, "<|LOC_200|>": 100497, "<|LOC_201|>": 100498, "<|LOC_202|>": 100499, "<|LOC_203|>": 100500, "<|LOC_204|>": 100501, "<|LOC_205|>": 100502, "<|LOC_206|>": 100503, "<|LOC_207|>": 100504, "<|LOC_208|>": 100505, "<|LOC_209|>": 100506, "<|LOC_210|>": 100507, "<|LOC_211|>": 100508, "<|LOC_212|>": 100509, "<|LOC_213|>": 100510, "<|LOC_214|>": 100511, "<|LOC_215|>": 100512, "<|LOC_216|>": 100513, "<|LOC_217|>": 100514, "<|LOC_218|>": 100515, "<|LOC_219|>": 100516, "<|LOC_220|>": 100517, "<|LOC_221|>": 100518, "<|LOC_222|>": 100519, "<|LOC_223|>": 100520, "<|LOC_224|>": 100521, "<|LOC_225|>": 100522, "<|LOC_226|>": 100523, "<|LOC_227|>": 100524, "<|LOC_228|>": 100525, "<|LOC_229|>": 100526, "<|LOC_230|>": 100527, "<|LOC_231|>": 100528, "<|LOC_232|>": 100529, "<|LOC_233|>": 100530, "<|LOC_234|>": 100531, "<|LOC_235|>": 100532, "<|LOC_236|>": 100533, "<|LOC_237|>": 100534, "<|LOC_238|>": 100535, "<|LOC_239|>": 100536, "<|LOC_240|>": 100537, "<|LOC_241|>": 100538, "<|LOC_242|>": 100539, "<|LOC_243|>": 100540, "<|LOC_244|>": 100541, "<|LOC_245|>": 100542, "<|LOC_246|>": 100543, "<|LOC_247|>": 100544, "<|LOC_248|>": 100545, "<|LOC_249|>": 100546, "<|LOC_250|>": 100547, "<|LOC_251|>": 100548, "<|LOC_252|>": 100549, "<|LOC_253|>": 100550, "<|LOC_254|>": 100551, "<|LOC_255|>": 100552, "<|LOC_256|>": 100553, "<|LOC_257|>": 100554, "<|LOC_258|>": 100555, "<|LOC_259|>": 100556, "<|LOC_260|>": 100557, "<|LOC_261|>": 100558, "<|LOC_262|>": 100559, "<|LOC_263|>": 100560, "<|LOC_264|>": 100561, "<|LOC_265|>": 100562, "<|LOC_266|>": 100563, "<|LOC_267|>": 100564, "<|LOC_268|>": 100565, "<|LOC_269|>": 100566, "<|LOC_270|>": 100567, "<|LOC_271|>": 100568, "<|LOC_272|>": 100569, "<|LOC_273|>": 100570, "<|LOC_274|>": 100571, "<|LOC_275|>": 100572, "<|LOC_276|>": 100573, "<|LOC_277|>": 100574, "<|LOC_278|>": 100575, "<|LOC_279|>": 100576, "<|LOC_280|>": 100577, "<|LOC_281|>": 100578, "<|LOC_282|>": 100579, "<|LOC_283|>": 100580, "<|LOC_284|>": 100581, "<|LOC_285|>": 100582, "<|LOC_286|>": 100583, "<|LOC_287|>": 100584, "<|LOC_288|>": 100585, "<|LOC_289|>": 100586, "<|LOC_290|>": 100587, "<|LOC_291|>": 100588, "<|LOC_292|>": 100589, "<|LOC_293|>": 100590, "<|LOC_294|>": 100591, "<|LOC_295|>": 100592, "<|LOC_296|>": 100593, "<|LOC_297|>": 100594, "<|LOC_298|>": 100595, "<|LOC_299|>": 100596, "<|LOC_300|>": 100597, "<|LOC_301|>": 100598, "<|LOC_302|>": 100599, "<|LOC_303|>": 100600, "<|LOC_304|>": 100601, "<|LOC_305|>": 100602, "<|LOC_306|>": 100603, "<|LOC_307|>": 100604, "<|LOC_308|>": 100605, "<|LOC_309|>": 100606, 
"<|LOC_310|>": 100607, "<|LOC_311|>": 100608, "<|LOC_312|>": 100609, "<|LOC_313|>": 100610, "<|LOC_314|>": 100611, "<|LOC_315|>": 100612, "<|LOC_316|>": 100613, "<|LOC_317|>": 100614, "<|LOC_318|>": 100615, "<|LOC_319|>": 100616, "<|LOC_320|>": 100617, "<|LOC_321|>": 100618, "<|LOC_322|>": 100619, "<|LOC_323|>": 100620, "<|LOC_324|>": 100621, "<|LOC_325|>": 100622, "<|LOC_326|>": 100623, "<|LOC_327|>": 100624, "<|LOC_328|>": 100625, "<|LOC_329|>": 100626, "<|LOC_330|>": 100627, "<|LOC_331|>": 100628, "<|LOC_332|>": 100629, "<|LOC_333|>": 100630, "<|LOC_334|>": 100631, "<|LOC_335|>": 100632, "<|LOC_336|>": 100633, "<|LOC_337|>": 100634, "<|LOC_338|>": 100635, "<|LOC_339|>": 100636, "<|LOC_340|>": 100637, "<|LOC_341|>": 100638, "<|LOC_342|>": 100639, "<|LOC_343|>": 100640, "<|LOC_344|>": 100641, "<|LOC_345|>": 100642, "<|LOC_346|>": 100643, "<|LOC_347|>": 100644, "<|LOC_348|>": 100645, "<|LOC_349|>": 100646, "<|LOC_350|>": 100647, "<|LOC_351|>": 100648, "<|LOC_352|>": 100649, "<|LOC_353|>": 100650, "<|LOC_354|>": 100651, "<|LOC_355|>": 100652, "<|LOC_356|>": 100653, "<|LOC_357|>": 100654, "<|LOC_358|>": 100655, "<|LOC_359|>": 100656, "<|LOC_360|>": 100657, "<|LOC_361|>": 100658, "<|LOC_362|>": 100659, "<|LOC_363|>": 100660, "<|LOC_364|>": 100661, "<|LOC_365|>": 100662, "<|LOC_366|>": 100663, "<|LOC_367|>": 100664, "<|LOC_368|>": 100665, "<|LOC_369|>": 100666, "<|LOC_370|>": 100667, "<|LOC_371|>": 100668, "<|LOC_372|>": 100669, "<|LOC_373|>": 100670, "<|LOC_374|>": 100671, "<|LOC_375|>": 100672, "<|LOC_376|>": 100673, "<|LOC_377|>": 100674, "<|LOC_378|>": 100675, "<|LOC_379|>": 100676, "<|LOC_380|>": 100677, "<|LOC_381|>": 100678, "<|LOC_382|>": 100679, "<|LOC_383|>": 100680, "<|LOC_384|>": 100681, "<|LOC_385|>": 100682, "<|LOC_386|>": 100683, "<|LOC_387|>": 100684, "<|LOC_388|>": 100685, "<|LOC_389|>": 100686, "<|LOC_390|>": 100687, "<|LOC_391|>": 100688, "<|LOC_392|>": 100689, "<|LOC_393|>": 100690, "<|LOC_394|>": 100691, "<|LOC_395|>": 100692, "<|LOC_396|>": 100693, "<|LOC_397|>": 100694, "<|LOC_398|>": 100695, "<|LOC_399|>": 100696, "<|LOC_400|>": 100697, "<|LOC_401|>": 100698, "<|LOC_402|>": 100699, "<|LOC_403|>": 100700, "<|LOC_404|>": 100701, "<|LOC_405|>": 100702, "<|LOC_406|>": 100703, "<|LOC_407|>": 100704, "<|LOC_408|>": 100705, "<|LOC_409|>": 100706, "<|LOC_410|>": 100707, "<|LOC_411|>": 100708, "<|LOC_412|>": 100709, "<|LOC_413|>": 100710, "<|LOC_414|>": 100711, "<|LOC_415|>": 100712, "<|LOC_416|>": 100713, "<|LOC_417|>": 100714, "<|LOC_418|>": 100715, "<|LOC_419|>": 100716, "<|LOC_420|>": 100717, "<|LOC_421|>": 100718, "<|LOC_422|>": 100719, "<|LOC_423|>": 100720, "<|LOC_424|>": 100721, "<|LOC_425|>": 100722, "<|LOC_426|>": 100723, "<|LOC_427|>": 100724, "<|LOC_428|>": 100725, "<|LOC_429|>": 100726, "<|LOC_430|>": 100727, "<|LOC_431|>": 100728, "<|LOC_432|>": 100729, "<|LOC_433|>": 100730, "<|LOC_434|>": 100731, "<|LOC_435|>": 100732, "<|LOC_436|>": 100733, "<|LOC_437|>": 100734, "<|LOC_438|>": 100735, "<|LOC_439|>": 100736, "<|LOC_440|>": 100737, "<|LOC_441|>": 100738, "<|LOC_442|>": 100739, "<|LOC_443|>": 100740, "<|LOC_444|>": 100741, "<|LOC_445|>": 100742, "<|LOC_446|>": 100743, "<|LOC_447|>": 100744, "<|LOC_448|>": 100745, "<|LOC_449|>": 100746, "<|LOC_450|>": 100747, "<|LOC_451|>": 100748, "<|LOC_452|>": 100749, "<|LOC_453|>": 100750, "<|LOC_454|>": 100751, "<|LOC_455|>": 100752, "<|LOC_456|>": 100753, "<|LOC_457|>": 100754, "<|LOC_458|>": 100755, "<|LOC_459|>": 100756, "<|LOC_460|>": 100757, "<|LOC_461|>": 100758, "<|LOC_462|>": 100759, "<|LOC_463|>": 100760, 
"<|LOC_464|>": 100761, "<|LOC_465|>": 100762, "<|LOC_466|>": 100763, "<|LOC_467|>": 100764, "<|LOC_468|>": 100765, "<|LOC_469|>": 100766, "<|LOC_470|>": 100767, "<|LOC_471|>": 100768, "<|LOC_472|>": 100769, "<|LOC_473|>": 100770, "<|LOC_474|>": 100771, "<|LOC_475|>": 100772, "<|LOC_476|>": 100773, "<|LOC_477|>": 100774, "<|LOC_478|>": 100775, "<|LOC_479|>": 100776, "<|LOC_480|>": 100777, "<|LOC_481|>": 100778, "<|LOC_482|>": 100779, "<|LOC_483|>": 100780, "<|LOC_484|>": 100781, "<|LOC_485|>": 100782, "<|LOC_486|>": 100783, "<|LOC_487|>": 100784, "<|LOC_488|>": 100785, "<|LOC_489|>": 100786, "<|LOC_490|>": 100787, "<|LOC_491|>": 100788, "<|LOC_492|>": 100789, "<|LOC_493|>": 100790, "<|LOC_494|>": 100791, "<|LOC_495|>": 100792, "<|LOC_496|>": 100793, "<|LOC_497|>": 100794, "<|LOC_498|>": 100795, "<|LOC_499|>": 100796, "<|LOC_500|>": 100797, "<|LOC_501|>": 100798, "<|LOC_502|>": 100799, "<|LOC_503|>": 100800, "<|LOC_504|>": 100801, "<|LOC_505|>": 100802, "<|LOC_506|>": 100803, "<|LOC_507|>": 100804, "<|LOC_508|>": 100805, "<|LOC_509|>": 100806, "<|LOC_510|>": 100807, "<|LOC_511|>": 100808, "<|LOC_512|>": 100809, "<|LOC_513|>": 100810, "<|LOC_514|>": 100811, "<|LOC_515|>": 100812, "<|LOC_516|>": 100813, "<|LOC_517|>": 100814, "<|LOC_518|>": 100815, "<|LOC_519|>": 100816, "<|LOC_520|>": 100817, "<|LOC_521|>": 100818, "<|LOC_522|>": 100819, "<|LOC_523|>": 100820, "<|LOC_524|>": 100821, "<|LOC_525|>": 100822, "<|LOC_526|>": 100823, "<|LOC_527|>": 100824, "<|LOC_528|>": 100825, "<|LOC_529|>": 100826, "<|LOC_530|>": 100827, "<|LOC_531|>": 100828, "<|LOC_532|>": 100829, "<|LOC_533|>": 100830, "<|LOC_534|>": 100831, "<|LOC_535|>": 100832, "<|LOC_536|>": 100833, "<|LOC_537|>": 100834, "<|LOC_538|>": 100835, "<|LOC_539|>": 100836, "<|LOC_540|>": 100837, "<|LOC_541|>": 100838, "<|LOC_542|>": 100839, "<|LOC_543|>": 100840, "<|LOC_544|>": 100841, "<|LOC_545|>": 100842, "<|LOC_546|>": 100843, "<|LOC_547|>": 100844, "<|LOC_548|>": 100845, "<|LOC_549|>": 100846, "<|LOC_550|>": 100847, "<|LOC_551|>": 100848, "<|LOC_552|>": 100849, "<|LOC_553|>": 100850, "<|LOC_554|>": 100851, "<|LOC_555|>": 100852, "<|LOC_556|>": 100853, "<|LOC_557|>": 100854, "<|LOC_558|>": 100855, "<|LOC_559|>": 100856, "<|LOC_560|>": 100857, "<|LOC_561|>": 100858, "<|LOC_562|>": 100859, "<|LOC_563|>": 100860, "<|LOC_564|>": 100861, "<|LOC_565|>": 100862, "<|LOC_566|>": 100863, "<|LOC_567|>": 100864, "<|LOC_568|>": 100865, "<|LOC_569|>": 100866, "<|LOC_570|>": 100867, "<|LOC_571|>": 100868, "<|LOC_572|>": 100869, "<|LOC_573|>": 100870, "<|LOC_574|>": 100871, "<|LOC_575|>": 100872, "<|LOC_576|>": 100873, "<|LOC_577|>": 100874, "<|LOC_578|>": 100875, "<|LOC_579|>": 100876, "<|LOC_580|>": 100877, "<|LOC_581|>": 100878, "<|LOC_582|>": 100879, "<|LOC_583|>": 100880, "<|LOC_584|>": 100881, "<|LOC_585|>": 100882, "<|LOC_586|>": 100883, "<|LOC_587|>": 100884, "<|LOC_588|>": 100885, "<|LOC_589|>": 100886, "<|LOC_590|>": 100887, "<|LOC_591|>": 100888, "<|LOC_592|>": 100889, "<|LOC_593|>": 100890, "<|LOC_594|>": 100891, "<|LOC_595|>": 100892, "<|LOC_596|>": 100893, "<|LOC_597|>": 100894, "<|LOC_598|>": 100895, "<|LOC_599|>": 100896, "<|LOC_600|>": 100897, "<|LOC_601|>": 100898, "<|LOC_602|>": 100899, "<|LOC_603|>": 100900, "<|LOC_604|>": 100901, "<|LOC_605|>": 100902, "<|LOC_606|>": 100903, "<|LOC_607|>": 100904, "<|LOC_608|>": 100905, "<|LOC_609|>": 100906, "<|LOC_610|>": 100907, "<|LOC_611|>": 100908, "<|LOC_612|>": 100909, "<|LOC_613|>": 100910, "<|LOC_614|>": 100911, "<|LOC_615|>": 100912, "<|LOC_616|>": 100913, "<|LOC_617|>": 100914, 
"<|LOC_618|>": 100915, "<|LOC_619|>": 100916, "<|LOC_620|>": 100917, "<|LOC_621|>": 100918, "<|LOC_622|>": 100919, "<|LOC_623|>": 100920, "<|LOC_624|>": 100921, "<|LOC_625|>": 100922, "<|LOC_626|>": 100923, "<|LOC_627|>": 100924, "<|LOC_628|>": 100925, "<|LOC_629|>": 100926, "<|LOC_630|>": 100927, "<|LOC_631|>": 100928, "<|LOC_632|>": 100929, "<|LOC_633|>": 100930, "<|LOC_634|>": 100931, "<|LOC_635|>": 100932, "<|LOC_636|>": 100933, "<|LOC_637|>": 100934, "<|LOC_638|>": 100935, "<|LOC_639|>": 100936, "<|LOC_640|>": 100937, "<|LOC_641|>": 100938, "<|LOC_642|>": 100939, "<|LOC_643|>": 100940, "<|LOC_644|>": 100941, "<|LOC_645|>": 100942, "<|LOC_646|>": 100943, "<|LOC_647|>": 100944, "<|LOC_648|>": 100945, "<|LOC_649|>": 100946, "<|LOC_650|>": 100947, "<|LOC_651|>": 100948, "<|LOC_652|>": 100949, "<|LOC_653|>": 100950, "<|LOC_654|>": 100951, "<|LOC_655|>": 100952, "<|LOC_656|>": 100953, "<|LOC_657|>": 100954, "<|LOC_658|>": 100955, "<|LOC_659|>": 100956, "<|LOC_660|>": 100957, "<|LOC_661|>": 100958, "<|LOC_662|>": 100959, "<|LOC_663|>": 100960, "<|LOC_664|>": 100961, "<|LOC_665|>": 100962, "<|LOC_666|>": 100963, "<|LOC_667|>": 100964, "<|LOC_668|>": 100965, "<|LOC_669|>": 100966, "<|LOC_670|>": 100967, "<|LOC_671|>": 100968, "<|LOC_672|>": 100969, "<|LOC_673|>": 100970, "<|LOC_674|>": 100971, "<|LOC_675|>": 100972, "<|LOC_676|>": 100973, "<|LOC_677|>": 100974, "<|LOC_678|>": 100975, "<|LOC_679|>": 100976, "<|LOC_680|>": 100977, "<|LOC_681|>": 100978, "<|LOC_682|>": 100979, "<|LOC_683|>": 100980, "<|LOC_684|>": 100981, "<|LOC_685|>": 100982, "<|LOC_686|>": 100983, "<|LOC_687|>": 100984, "<|LOC_688|>": 100985, "<|LOC_689|>": 100986, "<|LOC_690|>": 100987, "<|LOC_691|>": 100988, "<|LOC_692|>": 100989, "<|LOC_693|>": 100990, "<|LOC_694|>": 100991, "<|LOC_695|>": 100992, "<|LOC_696|>": 100993, "<|LOC_697|>": 100994, "<|LOC_698|>": 100995, "<|LOC_699|>": 100996, "<|LOC_700|>": 100997, "<|LOC_701|>": 100998, "<|LOC_702|>": 100999, "<|LOC_703|>": 101000, "<|LOC_704|>": 101001, "<|LOC_705|>": 101002, "<|LOC_706|>": 101003, "<|LOC_707|>": 101004, "<|LOC_708|>": 101005, "<|LOC_709|>": 101006, "<|LOC_710|>": 101007, "<|LOC_711|>": 101008, "<|LOC_712|>": 101009, "<|LOC_713|>": 101010, "<|LOC_714|>": 101011, "<|LOC_715|>": 101012, "<|LOC_716|>": 101013, "<|LOC_717|>": 101014, "<|LOC_718|>": 101015, "<|LOC_719|>": 101016, "<|LOC_720|>": 101017, "<|LOC_721|>": 101018, "<|LOC_722|>": 101019, "<|LOC_723|>": 101020, "<|LOC_724|>": 101021, "<|LOC_725|>": 101022, "<|LOC_726|>": 101023, "<|LOC_727|>": 101024, "<|LOC_728|>": 101025, "<|LOC_729|>": 101026, "<|LOC_730|>": 101027, "<|LOC_731|>": 101028, "<|LOC_732|>": 101029, "<|LOC_733|>": 101030, "<|LOC_734|>": 101031, "<|LOC_735|>": 101032, "<|LOC_736|>": 101033, "<|LOC_737|>": 101034, "<|LOC_738|>": 101035, "<|LOC_739|>": 101036, "<|LOC_740|>": 101037, "<|LOC_741|>": 101038, "<|LOC_742|>": 101039, "<|LOC_743|>": 101040, "<|LOC_744|>": 101041, "<|LOC_745|>": 101042, "<|LOC_746|>": 101043, "<|LOC_747|>": 101044, "<|LOC_748|>": 101045, "<|LOC_749|>": 101046, "<|LOC_750|>": 101047, "<|LOC_751|>": 101048, "<|LOC_752|>": 101049, "<|LOC_753|>": 101050, "<|LOC_754|>": 101051, "<|LOC_755|>": 101052, "<|LOC_756|>": 101053, "<|LOC_757|>": 101054, "<|LOC_758|>": 101055, "<|LOC_759|>": 101056, "<|LOC_760|>": 101057, "<|LOC_761|>": 101058, "<|LOC_762|>": 101059, "<|LOC_763|>": 101060, "<|LOC_764|>": 101061, "<|LOC_765|>": 101062, "<|LOC_766|>": 101063, "<|LOC_767|>": 101064, "<|LOC_768|>": 101065, "<|LOC_769|>": 101066, "<|LOC_770|>": 101067, "<|LOC_771|>": 101068, 
"<|LOC_772|>": 101069, "<|LOC_773|>": 101070, "<|LOC_774|>": 101071, "<|LOC_775|>": 101072, "<|LOC_776|>": 101073, "<|LOC_777|>": 101074, "<|LOC_778|>": 101075, "<|LOC_779|>": 101076, "<|LOC_780|>": 101077, "<|LOC_781|>": 101078, "<|LOC_782|>": 101079, "<|LOC_783|>": 101080, "<|LOC_784|>": 101081, "<|LOC_785|>": 101082, "<|LOC_786|>": 101083, "<|LOC_787|>": 101084, "<|LOC_788|>": 101085, "<|LOC_789|>": 101086, "<|LOC_790|>": 101087, "<|LOC_791|>": 101088, "<|LOC_792|>": 101089, "<|LOC_793|>": 101090, "<|LOC_794|>": 101091, "<|LOC_795|>": 101092, "<|LOC_796|>": 101093, "<|LOC_797|>": 101094, "<|LOC_798|>": 101095, "<|LOC_799|>": 101096, "<|LOC_800|>": 101097, "<|LOC_801|>": 101098, "<|LOC_802|>": 101099, "<|LOC_803|>": 101100, "<|LOC_804|>": 101101, "<|LOC_805|>": 101102, "<|LOC_806|>": 101103, "<|LOC_807|>": 101104, "<|LOC_808|>": 101105, "<|LOC_809|>": 101106, "<|LOC_810|>": 101107, "<|LOC_811|>": 101108, "<|LOC_812|>": 101109, "<|LOC_813|>": 101110, "<|LOC_814|>": 101111, "<|LOC_815|>": 101112, "<|LOC_816|>": 101113, "<|LOC_817|>": 101114, "<|LOC_818|>": 101115, "<|LOC_819|>": 101116, "<|LOC_820|>": 101117, "<|LOC_821|>": 101118, "<|LOC_822|>": 101119, "<|LOC_823|>": 101120, "<|LOC_824|>": 101121, "<|LOC_825|>": 101122, "<|LOC_826|>": 101123, "<|LOC_827|>": 101124, "<|LOC_828|>": 101125, "<|LOC_829|>": 101126, "<|LOC_830|>": 101127, "<|LOC_831|>": 101128, "<|LOC_832|>": 101129, "<|LOC_833|>": 101130, "<|LOC_834|>": 101131, "<|LOC_835|>": 101132, "<|LOC_836|>": 101133, "<|LOC_837|>": 101134, "<|LOC_838|>": 101135, "<|LOC_839|>": 101136, "<|LOC_840|>": 101137, "<|LOC_841|>": 101138, "<|LOC_842|>": 101139, "<|LOC_843|>": 101140, "<|LOC_844|>": 101141, "<|LOC_845|>": 101142, "<|LOC_846|>": 101143, "<|LOC_847|>": 101144, "<|LOC_848|>": 101145, "<|LOC_849|>": 101146, "<|LOC_850|>": 101147, "<|LOC_851|>": 101148, "<|LOC_852|>": 101149, "<|LOC_853|>": 101150, "<|LOC_854|>": 101151, "<|LOC_855|>": 101152, "<|LOC_856|>": 101153, "<|LOC_857|>": 101154, "<|LOC_858|>": 101155, "<|LOC_859|>": 101156, "<|LOC_860|>": 101157, "<|LOC_861|>": 101158, "<|LOC_862|>": 101159, "<|LOC_863|>": 101160, "<|LOC_864|>": 101161, "<|LOC_865|>": 101162, "<|LOC_866|>": 101163, "<|LOC_867|>": 101164, "<|LOC_868|>": 101165, "<|LOC_869|>": 101166, "<|LOC_870|>": 101167, "<|LOC_871|>": 101168, "<|LOC_872|>": 101169, "<|LOC_873|>": 101170, "<|LOC_874|>": 101171, "<|LOC_875|>": 101172, "<|LOC_876|>": 101173, "<|LOC_877|>": 101174, "<|LOC_878|>": 101175, "<|LOC_879|>": 101176, "<|LOC_880|>": 101177, "<|LOC_881|>": 101178, "<|LOC_882|>": 101179, "<|LOC_883|>": 101180, "<|LOC_884|>": 101181, "<|LOC_885|>": 101182, "<|LOC_886|>": 101183, "<|LOC_887|>": 101184, "<|LOC_888|>": 101185, "<|LOC_889|>": 101186, "<|LOC_890|>": 101187, "<|LOC_891|>": 101188, "<|LOC_892|>": 101189, "<|LOC_893|>": 101190, "<|LOC_894|>": 101191, "<|LOC_895|>": 101192, "<|LOC_896|>": 101193, "<|LOC_897|>": 101194, "<|LOC_898|>": 101195, "<|LOC_899|>": 101196, "<|LOC_900|>": 101197, "<|LOC_901|>": 101198, "<|LOC_902|>": 101199, "<|LOC_903|>": 101200, "<|LOC_904|>": 101201, "<|LOC_905|>": 101202, "<|LOC_906|>": 101203, "<|LOC_907|>": 101204, "<|LOC_908|>": 101205, "<|LOC_909|>": 101206, "<|LOC_910|>": 101207, "<|LOC_911|>": 101208, "<|LOC_912|>": 101209, "<|LOC_913|>": 101210, "<|LOC_914|>": 101211, "<|LOC_915|>": 101212, "<|LOC_916|>": 101213, "<|LOC_917|>": 101214, "<|LOC_918|>": 101215, "<|LOC_919|>": 101216, "<|LOC_920|>": 101217, "<|LOC_921|>": 101218, "<|LOC_922|>": 101219, "<|LOC_923|>": 101220, "<|LOC_924|>": 101221, "<|LOC_925|>": 101222, 
"<|LOC_926|>": 101223, "<|LOC_927|>": 101224, "<|LOC_928|>": 101225, "<|LOC_929|>": 101226, "<|LOC_930|>": 101227, "<|LOC_931|>": 101228, "<|LOC_932|>": 101229, "<|LOC_933|>": 101230, "<|LOC_934|>": 101231, "<|LOC_935|>": 101232, "<|LOC_936|>": 101233, "<|LOC_937|>": 101234, "<|LOC_938|>": 101235, "<|LOC_939|>": 101236, "<|LOC_940|>": 101237, "<|LOC_941|>": 101238, "<|LOC_942|>": 101239, "<|LOC_943|>": 101240, "<|LOC_944|>": 101241, "<|LOC_945|>": 101242, "<|LOC_946|>": 101243, "<|LOC_947|>": 101244, "<|LOC_948|>": 101245, "<|LOC_949|>": 101246, "<|LOC_950|>": 101247, "<|LOC_951|>": 101248, "<|LOC_952|>": 101249, "<|LOC_953|>": 101250, "<|LOC_954|>": 101251, "<|LOC_955|>": 101252, "<|LOC_956|>": 101253, "<|LOC_957|>": 101254, "<|LOC_958|>": 101255, "<|LOC_959|>": 101256, "<|LOC_960|>": 101257, "<|LOC_961|>": 101258, "<|LOC_962|>": 101259, "<|LOC_963|>": 101260, "<|LOC_964|>": 101261, "<|LOC_965|>": 101262, "<|LOC_966|>": 101263, "<|LOC_967|>": 101264, "<|LOC_968|>": 101265, "<|LOC_969|>": 101266, "<|LOC_970|>": 101267, "<|LOC_971|>": 101268, "<|LOC_972|>": 101269, "<|LOC_973|>": 101270, "<|LOC_974|>": 101271, "<|LOC_975|>": 101272, "<|LOC_976|>": 101273, "<|LOC_977|>": 101274, "<|LOC_978|>": 101275, "<|LOC_979|>": 101276, "<|LOC_980|>": 101277, "<|LOC_981|>": 101278, "<|LOC_982|>": 101279, "<|LOC_983|>": 101280, "<|LOC_984|>": 101281, "<|LOC_985|>": 101282, "<|LOC_986|>": 101283, "<|LOC_987|>": 101284, "<|LOC_988|>": 101285, "<|LOC_989|>": 101286, "<|LOC_990|>": 101287, "<|LOC_991|>": 101288, "<|LOC_992|>": 101289, "<|LOC_993|>": 101290, "<|LOC_994|>": 101291, "<|LOC_995|>": 101292, "<|LOC_996|>": 101293, "<|LOC_997|>": 101294, "<|LOC_998|>": 101295, "<|LOC_999|>": 101296, "<|LOC_1000|>": 101297, "<|LOC_BEGIN|>": 101298, "<|LOC_END|>": 101299, "<|LOC_SEP|>": 101300, "<|CROP_COL_SEP|>": 101301, "<|CROP_ROW_SEP|>": 101302, "<|IMAGE_SEP|>": 101303}
config.json
ADDED
@@ -0,0 +1,36 @@
{
    "architectures": [
        "Ernie4_5_MoeForCausalLM"
    ],
    "auto_map": {
        "AutoConfig": "configuration_ernie4_5_moe.Ernie4_5_MoeConfig",
        "AutoModel": "modeling_ernie4_5_moe.Ernie4_5_Model",
        "AutoModelForCausalLM": "modeling_ernie4_5_moe.Ernie4_5_MoeForCausalLM"
    },
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 2560,
    "intermediate_size": 12288,
    "max_position_embeddings": 131072,
    "model_type": "ernie4_5_moe",
    "num_attention_heads": 20,
    "num_key_value_heads": 4,
    "num_hidden_layers": 28,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-05,
    "use_cache": false,
    "vocab_size": 103424,
    "rope_theta": 500000,
    "tie_word_embeddings": true,
    "use_bias": false,
    "moe_num_experts": 64,
    "moe_num_shared_experts": 2,
    "moe_layer_start_index": 1,
    "moe_intermediate_size": 1536,
    "moe_capacity": [64,64,64],
    "moe_k": 6,
    "moe_layer_interval": 1,
    "moe_use_aux_free": true,
    "torch_dtype": "bfloat16"
}
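
The `auto_map` block above is what lets the generic `Auto*` classes resolve to the custom ERNIE implementations shipped in this repository when `trust_remote_code=True` is passed. A minimal sketch (repo id assumed to match the README):

```python
from transformers import AutoConfig

# auto_map routes AutoConfig to Ernie4_5_MoeConfig from
# configuration_ernie4_5_moe.py in this repository.
config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-21B-A3B-Base-PT", trust_remote_code=True)
print(type(config).__name__)                  # Ernie4_5_MoeConfig
print(config.moe_num_experts, config.moe_k)   # 64 6
```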
configuration_ernie4_5_moe.py
ADDED
@@ -0,0 +1,194 @@
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import PretrainedConfig


class Ernie4_5_MoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Ernie4_5_Model`].
    It is used to instantiate an ERNIE-4.5 model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults
    will yield a similar configuration to that of ERNIE-4.5-21B-A3B-Base-PT [baidu/ERNIE-4.5-21B-A3B-Base-PT].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (int): Size of the vocabulary (number of unique tokens)
        hidden_size (int): Dimensionality of the encoder layers and the pooler layer
        intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
        max_position_embeddings (int): Maximum sequence length the model can handle
        num_hidden_layers (int): Number of hidden layers in the Transformer encoder
        num_attention_heads (int): Number of attention heads for each attention layer
        rms_norm_eps (float): The epsilon used by the RMS normalization layers
        use_cache (bool): Whether to use caching for faster generation (decoding)
        use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
        pad_token_id (int): Token ID used for padding sequences
        bos_token_id (int): Token ID used for beginning-of-sequence
        eos_token_id (int): Token ID used for end-of-sequence
        use_bias (bool): Whether to use bias terms in linear layers
        rope_theta (float): The base period of the RoPE embeddings
        weight_share_add_bias (bool): Whether to share bias weights in certain layers
        ignored_index (int): Target value that is ignored during loss computation
        attention_probs_dropout_prob (float): Dropout probability for attention weights
        hidden_dropout_prob (float): Dropout probability for hidden layers
        num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
        max_sequence_length (int): Maximum sequence length for positional embeddings
        moe_num_experts: Number of experts in MoE layers
        moe_capacity: Capacity configuration for MoE layers
        moe_layer_interval: Interval between MoE layers
        moe_layer_start_index: Starting layer index for MoE
        moe_layer_end_index: Ending layer index for MoE (-1 means last layer)
        sinkhorn_2gate: Whether to use sinkhorn 2-gate routing
        sinkhorn_temp: Temperature for sinkhorn routing
        moe_dropout_prob: Dropout probability for MoE layers
        moe_gate: Type of gating mechanism ('top2', etc.)
        moe_intermediate_size: Intermediate size for MoE layers
        moe_gate_act: Activation function for gating
        moe_k: Number of experts to route to
        **kwargs: Additional base model configuration parameters
    """

    model_type = "ernie4_5_moe"
    use_keep_in_fp32_modules = True
    keys_to_ignore_at_inference = ["past_key_values"]

    attribute_map = {
        "n_positions": "max_position_embeddings",
        "n_embd": "hidden_size",
        "n_layer": "num_hidden_layers",
        "n_head": "num_attention_heads",
        "n_inner": "intermediate_size",
        "activation_function": "hidden_act",
    }

    # Default tensor parallel plan for base model `ernie_4_5_moe`
    base_model_tp_plan = {
        "model.layers.*.self_attn.q_proj": "colwise_rep",
        "model.layers.*.self_attn.k_proj": "colwise_rep",
        "model.layers.*.self_attn.v_proj": "colwise_rep",
        "model.layers.*.self_attn.o_proj": "rowwise_rep",
        "model.layers.*.mlp.experts.*.gate_proj": "colwise",
        "model.layers.*.mlp.experts.*.up_proj": "colwise",
        "model.layers.*.mlp.experts.*.down_proj": "rowwise",
        "model.layers.*.mlp.gate_proj": "colwise",
        "model.layers.*.mlp.up_proj": "colwise",
        "model.layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=11008,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=None,
        max_position_embeddings=32768,
        use_sliding_window=None,
        sliding_window=None,
        rms_norm_eps=1e-6,
        use_cache=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        attention_probs_dropout_prob=0.0,
        hidden_dropout_prob=0.0,
        rope_theta=10000.0,
        use_flash_attention=False,
        use_rmsnorm=True,
        use_bias=False,
        weight_share_add_bias=True,
        max_sequence_length=None,
        ignored_index=-100,
        use_moe=True,
        moe_num_experts=64,
        moe_capacity=(64, 64, 64),
        moe_layer_interval=2,
        moe_layer_start_index=0,
        moe_layer_end_index=-1,
        sinkhorn_2gate=True,
        sinkhorn_temp=3e-2,
        moe_dropout_prob=0.0,
        moe_gate="top2",
        moe_intermediate_size=3584,
        moe_k=2,
        moe_gate_act: str = "softmax",
        moe_use_aux_free=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.use_rmsnorm = use_rmsnorm
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.max_sequence_length = max_sequence_length
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.ignored_index = ignored_index
        self.use_cache = use_cache
        self.use_bias = use_bias
        self.weight_share_add_bias = weight_share_add_bias
        self.use_flash_attention = use_flash_attention
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob

        self.use_moe = moe_num_experts > 0 and use_moe
        self.moe_num_experts = moe_num_experts
        self.moe_capacity = moe_capacity
        self.sinkhorn_2gate = sinkhorn_2gate
        self.sinkhorn_temp = sinkhorn_temp
        self.moe_layer_interval = moe_layer_interval
        self.moe_dropout_prob = moe_dropout_prob
        self.moe_gate = moe_gate
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_k = moe_k
        self.moe_layer_start_index = moe_layer_start_index
        self.moe_layer_end_index = (
            self.num_hidden_layers - 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from transformers import PretrainedConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Ernie4_5_MoeConfig(PretrainedConfig):
|
| 19 |
+
r"""
|
| 20 |
+
This is the configuration class to store the configuration of a [`Ernie4_5_Model`].
|
| 21 |
+
It is used to instantiate an ERNIE-4.5 model according to the specified arguments,
|
| 22 |
+
defining the model architecture. Instantiating a configuration with the defaults
|
| 23 |
+
will yield a similar configuration to that of ERNIE-4.5-21B-A3B-Base-PT [baidu/ERNIE-4.5-21B-A3B-Base-PT].
|
| 24 |
+
|
| 25 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 26 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
vocab_size (int): Size of the vocabulary (number of unique tokens)
|
| 31 |
+
hidden_size (int): Dimensionality of the encoder layers and the pooler layer
|
| 32 |
+
intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
|
| 33 |
+
max_position_embeddings (int): Maximum sequence length the model can handle
|
| 34 |
+
num_hidden_layers (int): Number of hidden layers in the Transformer encoder
|
| 35 |
+
num_attention_heads (int): Number of attention heads for each attention layer
|
| 36 |
+
rms_norm_eps (float): The epsilon used by the RMS normalization layers
|
| 37 |
+
use_cache (bool): Whether to use caching for faster generation (decoding)
|
| 38 |
+
use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
|
| 39 |
+
pad_token_id (int): Token ID used for padding sequences
|
| 40 |
+
bos_token_id (int): Token ID used for beginning-of-sequence
|
| 41 |
+
eos_token_id (int): Token ID used for end-of-sequence
|
| 42 |
+
use_bias (bool): Whether to use bias terms in linear layers
|
| 43 |
+
rope_theta (float): The base period of the RoPE embeddings
|
| 44 |
+
weight_share_add_bias (bool): Whether to share bias weights in certain layers
|
| 45 |
+
ignored_index (int): Target value that is ignored during loss computation
|
| 46 |
+
attention_probs_dropout_prob (float): Dropout probability for attention weights
|
| 47 |
+
hidden_dropout_prob (float): Dropout probability for hidden layers
|
| 48 |
+
num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
|
| 49 |
+
max_sequence_length (int): Maximum sequence length for positional embeddings
|
| 50 |
+
moe_num_experts: Number of experts in MoE layers
|
| 51 |
+
moe_capacity: Capacity configuration for MoE layers
|
| 52 |
+
moe_layer_interval: Interval between MoE layers
|
| 53 |
+
moe_layer_start_index: Starting layer index for MoE
|
| 54 |
+
moe_layer_end_index: Ending layer index for MoE (-1 means last layer)
|
| 55 |
+
sinkhorn_2gate: Whether to use sinkhorn 2-gate routing
|
| 56 |
+
sinkhorn_temp: Temperature for sinkhorn routing
|
| 57 |
+
moe_dropout_prob: Dropout probability for MoE layers
|
| 58 |
+
moe_gate: Type of gating mechanism ('top2', etc.)
|
| 59 |
+
moe_intermediate_size: Intermediate size for MoE layers
|
| 60 |
+
moe_gate_act: Activation function for gating
|
| 61 |
+
moe_k: Number of experts to route to
|
| 62 |
+
**kwargs: Additional base model configuration parameters
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
model_type = "ernie4_5_moe"
|
| 66 |
+
use_keep_in_fp32_modules = True
|
| 67 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 68 |
+
|
| 69 |
+
attribute_map = {
|
| 70 |
+
"n_positions": "max_position_embeddings",
|
| 71 |
+
"n_embd": "hidden_size",
|
| 72 |
+
"n_layer": "num_hidden_layers",
|
| 73 |
+
"n_head": "num_attention_heads",
|
| 74 |
+
"n_inner": "intermediate_size",
|
| 75 |
+
"activation_function": "hidden_act",
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
# Default tensor parallel plan for base model `ernie_4_5_moe`
|
| 79 |
+
base_model_tp_plan = {
|
| 80 |
+
"model.layers.*.self_attn.q_proj": "colwise_rep",
|
| 81 |
+
"model.layers.*.self_attn.k_proj": "colwise_rep",
|
| 82 |
+
"model.layers.*.self_attn.v_proj": "colwise_rep",
|
| 83 |
+
"model.layers.*.self_attn.o_proj": "rowwise_rep",
|
| 84 |
+
"model.layers.*.mlp.experts.*.gate_proj": "colwise",
|
| 85 |
+
"model.layers.*.mlp.experts.*.up_proj": "colwise",
|
| 86 |
+
"model.layers.*.mlp.experts.*.down_proj": "rowwise",
|
| 87 |
+
"model.layers.*.mlp.gate_proj": "colwise",
|
| 88 |
+
"model.layers.*.mlp.up_proj": "colwise",
|
| 89 |
+
"model.layers.*.mlp.down_proj": "rowwise",
|
| 90 |
+
}
|
| 91 |
+
base_model_pp_plan = {
|
| 92 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 93 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 94 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
vocab_size=32000,
|
| 100 |
+
hidden_size=768,
|
| 101 |
+
intermediate_size=11008,
|
| 102 |
+
num_hidden_layers=2,
|
| 103 |
+
num_attention_heads=2,
|
| 104 |
+
num_key_value_heads=None,
|
| 105 |
+
max_position_embeddings=32768,
|
| 106 |
+
use_sliding_window=None,
|
| 107 |
+
sliding_window=None,
|
| 108 |
+
rms_norm_eps=1e-6,
|
| 109 |
+
use_cache=False,
|
| 110 |
+
pad_token_id=0,
|
| 111 |
+
bos_token_id=1,
|
| 112 |
+
eos_token_id=2,
|
| 113 |
+
attention_probs_dropout_prob=0.0,
|
| 114 |
+
hidden_dropout_prob=0.0,
|
| 115 |
+
rope_theta=10000.0,
|
| 116 |
+
use_flash_attention=False,
|
| 117 |
+
use_rmsnorm=True,
|
| 118 |
+
use_bias=False,
|
| 119 |
+
weight_share_add_bias=True,
|
| 120 |
+
max_sequence_length=None,
|
| 121 |
+
ignored_index=-100,
|
| 122 |
+
use_moe=True,
|
| 123 |
+
moe_num_experts=64,
|
| 124 |
+
moe_capacity=(64, 64, 64),
|
| 125 |
+
moe_layer_interval=2,
|
| 126 |
+
moe_layer_start_index=0,
|
| 127 |
+
moe_layer_end_index=-1,
|
| 128 |
+
sinkhorn_2gate=True,
|
| 129 |
+
sinkhorn_temp=3e-2,
|
| 130 |
+
moe_dropout_prob=0.0,
|
| 131 |
+
moe_gate="top2",
|
| 132 |
+
moe_intermediate_size=3584,
|
| 133 |
+
moe_k=2,
|
| 134 |
+
moe_gate_act: str = "softmax",
|
| 135 |
+
moe_use_aux_free=False,
|
| 136 |
+
**kwargs,
|
| 137 |
+
):
|
| 138 |
+
self.vocab_size = vocab_size
|
| 139 |
+
self.max_position_embeddings = max_position_embeddings
|
| 140 |
+
self.use_sliding_window = use_sliding_window
|
| 141 |
+
self.sliding_window = sliding_window
|
| 142 |
+
self.hidden_size = hidden_size
|
| 143 |
+
self.intermediate_size = intermediate_size
|
| 144 |
+
self.num_hidden_layers = num_hidden_layers
|
| 145 |
+
self.num_attention_heads = num_attention_heads
|
| 146 |
+
|
| 147 |
+
if num_key_value_heads is None:
|
| 148 |
+
num_key_value_heads = num_attention_heads
|
| 149 |
+
|
| 150 |
+
self.num_key_value_heads = num_key_value_heads
|
| 151 |
+
self.use_rmsnorm = use_rmsnorm
|
| 152 |
+
self.rms_norm_eps = rms_norm_eps
|
| 153 |
+
self.rope_theta = rope_theta
|
| 154 |
+
self.max_sequence_length = max_sequence_length
|
| 155 |
+
self.pad_token_id = pad_token_id
|
| 156 |
+
self.bos_token_id = bos_token_id
|
| 157 |
+
self.eos_token_id = eos_token_id
|
| 158 |
+
self.ignored_index = ignored_index
|
| 159 |
+
self.use_cache = use_cache
|
| 160 |
+
self.use_bias = use_bias
|
| 161 |
+
self.weight_share_add_bias = weight_share_add_bias
|
| 162 |
+
self.use_flash_attention = use_flash_attention
|
| 163 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 164 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 165 |
+
|
| 166 |
+
self.use_moe = moe_num_experts > 0 and use_moe
|
| 167 |
+
self.moe_num_experts = moe_num_experts
|
| 168 |
+
self.moe_capacity = moe_capacity
|
| 169 |
+
self.sinkhorn_2gate = sinkhorn_2gate
|
| 170 |
+
self.sinkhorn_temp = sinkhorn_temp
|
| 171 |
+
self.moe_layer_interval = moe_layer_interval
|
| 172 |
+
self.moe_dropout_prob = moe_dropout_prob
|
| 173 |
+
self.moe_gate = moe_gate
|
| 174 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 175 |
+
self.moe_k = moe_k
|
| 176 |
+
self.moe_layer_start_index = moe_layer_start_index
|
| 177 |
+
self.moe_layer_end_index = (
|
| 178 |
+
self.num_hidden_layers - 1
|
| 179 |
+
if moe_layer_end_index == -1
|
| 180 |
+
else moe_layer_end_index
|
| 181 |
+
)
|
| 182 |
+
self.moe_gate_act = moe_gate_act
|
| 183 |
+
self.moe_use_aux_free = moe_use_aux_free
|
| 184 |
+
|
| 185 |
+
# Set default for tied embeddings if not specified.
|
| 186 |
+
if "tie_word_embeddings" not in kwargs:
|
| 187 |
+
kwargs["tie_word_embeddings"] = False
|
| 188 |
+
|
| 189 |
+
super().__init__(
|
| 190 |
+
pad_token_id=pad_token_id,
|
| 191 |
+
bos_token_id=bos_token_id,
|
| 192 |
+
eos_token_id=eos_token_id,
|
| 193 |
+
**kwargs,
|
| 194 |
+
)
|
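For reference, a minimal sketch of instantiating this configuration class directly (assuming the repository root is on sys.path; the overrides shown are illustrative, not the released checkpoint's values):

from configuration_ernie4_5_moe import Ernie4_5_MoeConfig

# Small illustrative config: MoE in every second layer, top-2 routing
config = Ernie4_5_MoeConfig(
    num_hidden_layers=4,
    moe_num_experts=8,
    moe_layer_interval=2,
    moe_k=2,
)
print(config.use_moe)              # True, since moe_num_experts > 0 and use_moe defaults to True
print(config.moe_layer_end_index)  # 3, because -1 resolves to num_hidden_layers - 1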
generation_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
{
|
| 2 |
+
"do_sample": true,
|
| 3 |
+
"top_p": 0.8,
|
| 4 |
+
"temperature": 0.8,
|
| 5 |
+
"bos_token_id": 1,
|
| 6 |
+
"eos_token_id": 2,
|
| 7 |
+
"pad_token_id": 0,
|
| 8 |
+
"repetition_penalty": 1.0,
|
| 9 |
+
"frequency_penalty": 0.0,
|
| 10 |
+
"presence_penalty": 0.0
|
| 11 |
+
}
|
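A hedged sketch of how these sampling defaults would typically be picked up at generation time; "path/to/this/repo" is a placeholder for a local clone of this repository:

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("path/to/this/repo")
print(gen_cfg.do_sample, gen_cfg.top_p, gen_cfg.temperature)  # True 0.8 0.8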
model-00001-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b84bc85796a4712e157ca7c32647a74cbb0770e46402753f9593e80986db72ee
|
| 3 |
+
size 4189462080
|
model-00002-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a6b6b05bbd39efd7901e767f792098fc21c3619e12c9cdb6e9b7fc28b18c289
|
| 3 |
+
size 4192763368
|
model-00003-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f0565bf08269e8ba22a27a1755fc781f148e4fab7e2a1f2121c2907529bb4132
|
| 3 |
+
size 4192763368
|
model-00004-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5557526b58edbc86baf900db79073ef9938fe1cf2474913fafdec7e1ab3f02c
|
| 3 |
+
size 4192425256
|
model-00005-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be40f6217badb9fc6b4fd0ebc39631cc4252f036197447eb0a320c0c7b9e4eb7
|
| 3 |
+
size 4192763168
|
model-00006-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f436fb5dc6ab43bc395cf2a8cf8ad7728b0e863d1a8c66c6adccc788ea1582db
|
| 3 |
+
size 4192763368
|
model-00007-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d5148c7451ff907683d7de60b97e96629347d69a61bf77869037e3867ffa52a
|
| 3 |
+
size 4192425264
|
model-00008-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d895e488085b4a2f0d3d35827e3f22ac9467af78e2eacf2ce192d9e4ffcaa0c
|
| 3 |
+
size 4192763304
|
model-00009-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e62b11956c799e0403384ed0ea41fc31e549cbf04c28c8a0b45891a5af7fd45
|
| 3 |
+
size 4192424720
|
model-00010-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10854aafa195b0ff10f660cd8ecb1ccd461df5bec03d6e49d5ae367fa71c25f0
|
| 3 |
+
size 4192762840
|
model-00011-of-00011.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:87caf2ac32c375a8b165d38cbff951baf10b6edcfd6a3b356ab3cdb12553be85
|
| 3 |
+
size 2257771424
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
|
modeling_ernie4_5_moe.py
ADDED
|
@@ -0,0 +1,1412 @@
|
| 1 |
+
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
import functools
from functools import partial
|
| 18 |
+
from typing import Callable, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
|
| 24 |
+
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
| 25 |
+
from transformers.generation import GenerationMixin
|
| 26 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 27 |
+
from transformers.modeling_outputs import ModelOutput, MoeCausalLMOutputWithPast
|
| 28 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 29 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 30 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 31 |
+
from transformers.processing_utils import Unpack
|
| 32 |
+
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging, is_torch_flex_attn_available
|
| 33 |
+
|
| 34 |
+
from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if is_torch_flex_attn_available():
|
| 38 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 39 |
+
|
| 40 |
+
from transformers.integrations.flex_attention import make_flex_block_causal_mask
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class Erine4_5_MoeModelOutputWithPast(ModelOutput):
|
| 49 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 50 |
+
past_key_values: Optional[Cache] = None
|
| 51 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 52 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 53 |
+
router_loss: Optional[torch.FloatTensor] = None
|
| 54 |
+
gate_logits: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class Ernie4_5_MoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
|
| 59 |
+
router_loss: Optional[torch.FloatTensor] = None
|
| 60 |
+
|
| 61 |
+
def rotate_half(x):
|
| 62 |
+
"""Rotates half the hidden dims of the input."""
|
| 63 |
+
|
| 64 |
+
x1 = x[..., 0::2]
|
| 65 |
+
x2 = x[..., 1::2]
|
| 66 |
+
return torch.stack((-x2, x1), dim=-1).reshape(x.shape)
|
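A quick sanity check of the interleaved rotation (a worked example, not part of the released file): pairs (x1, x2) map to (-x2, x1).

import torch

x = torch.tensor([1., 2., 3., 4.])
assert torch.equal(rotate_half(x), torch.tensor([-2., 1., -4., 3.]))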
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 70 |
+
"""
|
| 71 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 72 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 73 |
+
"""
|
| 74 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 75 |
+
if n_rep == 1:
|
| 76 |
+
return hidden_states
|
| 77 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 78 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 82 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
q (`torch.Tensor`): The query tensor.
|
| 86 |
+
k (`torch.Tensor`): The key tensor.
|
| 87 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 88 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 89 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 90 |
+
Deprecated and unused.
|
| 91 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 92 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 93 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 94 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 95 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 96 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 97 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 98 |
+
Returns:
|
| 99 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 100 |
+
"""
|
| 101 |
+
orig_dtype = q.dtype
|
| 102 |
+
sin_pos = torch.stack([sin, sin], dim=-1).reshape(*sin.shape[:-1], -1).unsqueeze(unsqueeze_dim)
|
| 103 |
+
cos_pos = torch.stack([cos, cos], dim=-1).reshape(*cos.shape[:-1], -1).unsqueeze(unsqueeze_dim)
|
| 104 |
+
q_embed = (q.float() * cos_pos) + (rotate_half(q).float() * sin_pos)
|
| 105 |
+
k_embed = (k.float() * cos_pos) + (rotate_half(k).float() * sin_pos)
|
| 106 |
+
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def eager_attention_forward(
|
| 110 |
+
module: nn.Module,
|
| 111 |
+
query: torch.Tensor,
|
| 112 |
+
key: torch.Tensor,
|
| 113 |
+
value: torch.Tensor,
|
| 114 |
+
attention_mask: Optional[torch.Tensor],
|
| 115 |
+
scaling: float,
|
| 116 |
+
dropout: float = 0.0,
|
| 117 |
+
**kwargs,
|
| 118 |
+
):
|
| 119 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 120 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 121 |
+
|
| 122 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 123 |
+
if attention_mask is not None:
|
| 124 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 125 |
+
attn_weights = attn_weights + causal_mask.to(attn_weights.device)
|
| 126 |
+
|
| 127 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 128 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 129 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 130 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 131 |
+
|
| 132 |
+
return attn_output, attn_weights
|
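A minimal sketch of calling this reference path directly; the stand-in module is a hypothetical stub that only carries the two attributes the function reads (num_key_value_groups, training):

import types
import torch

stub = types.SimpleNamespace(num_key_value_groups=2, training=False)
q = torch.randn(1, 4, 6, 8)   # [batch, num_heads, seq_len, head_dim]
k = torch.randn(1, 2, 6, 8)   # [batch, num_kv_heads, seq_len, head_dim] (GQA: 4 heads / 2 kv heads)
v = torch.randn(1, 2, 6, 8)
out, weights = eager_attention_forward(stub, q, k, v, attention_mask=None, scaling=8 ** -0.5)
print(out.shape)              # torch.Size([1, 6, 4, 8]) -- transposed back to [batch, seq, heads, head_dim]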
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def topk_gate_func(
|
| 136 |
+
module: nn.Module,
|
| 137 |
+
hidden_states: torch.Tensor,
|
| 138 |
+
):
|
| 139 |
+
capacity = module.get_capacity(hidden_states.shape[0])
|
| 140 |
+
with torch.autocast(device_type='cuda', dtype=torch.float32):
|
| 141 |
+
logits = module.gate(hidden_states.float())
|
| 142 |
+
router_loss = torch.zeros([1], dtype=torch.float32, device=hidden_states.device)
|
| 143 |
+
router_loss = router_loss.detach()
|
| 144 |
+
return logits, capacity, router_loss
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class Ernie4_5_ResidualWithDropout(nn.Module):
|
| 148 |
+
"""
|
| 149 |
+
Fused dropout implementation with residual connection support.
|
| 150 |
+
|
| 151 |
+
This layer combines dropout and residual addition in a single operation for better performance,
|
| 152 |
+
particularly on GPU devices. The dropout is conditionally applied based on the probability.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
prob (float): Dropout probability (between 0 and 1)
|
| 156 |
+
|
| 157 |
+
Attributes:
|
| 158 |
+
prob (float): Stores the dropout probability
|
| 159 |
+
dropout (nn.Dropout): The actual dropout layer instance
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(self, prob):
|
| 163 |
+
"""
|
| 164 |
+
Initialize the fused dropout layer.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
prob (float): Dropout probability (0 means no dropout)
|
| 168 |
+
"""
|
| 169 |
+
super().__init__()
|
| 170 |
+
self.prob = prob
|
| 171 |
+
self.dropout = nn.Dropout(p=prob)
|
| 172 |
+
|
| 173 |
+
def forward(self, x, y):
|
| 174 |
+
"""
|
| 175 |
+
Forward pass of the fused dropout layer.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
x (torch.Tensor): Input tensor to potentially apply dropout on
|
| 179 |
+
y (torch.Tensor): Residual tensor to add to the (possibly dropped out) x
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
torch.Tensor: Result of x (with optional dropout) + y
|
| 183 |
+
"""
|
| 184 |
+
if self.prob > 0:
|
| 185 |
+
x = self.dropout(x)
|
| 186 |
+
output = x + y
|
| 187 |
+
|
| 188 |
+
return output
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class Ernie4_5_Attention(nn.Module):
|
| 192 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 193 |
+
|
| 194 |
+
def __init__(self, config, layer_idx=0):
|
| 195 |
+
"""
|
| 196 |
+
Args:
|
| 197 |
+
config (ErnieConfig): Model configuration.
|
| 198 |
+
layer_idx (int, optional): Index in transformer stack. Defaults to 0.
|
| 199 |
+
"""
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.layer_idx = layer_idx
|
| 202 |
+
self.hidden_size = config.hidden_size
|
| 203 |
+
self.num_heads = config.num_attention_heads
|
| 204 |
+
self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads is not None else self.num_heads
|
| 205 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 206 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 207 |
+
self.freq_allocation = getattr(config, "freq_allocation", 0)
|
| 208 |
+
self.scaling = self.head_dim**-0.5
|
| 209 |
+
self.attention_dropout = getattr(config, "attention_probs_dropout_prob", 0.0)
|
| 210 |
+
self.is_causal = True
|
| 211 |
+
|
| 212 |
+
self.q_proj = nn.Linear(
|
| 213 |
+
self.hidden_size,
|
| 214 |
+
self.num_heads * self.head_dim,
|
| 215 |
+
bias=config.use_bias,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
self.k_proj = nn.Linear(
|
| 219 |
+
self.hidden_size,
|
| 220 |
+
self.num_key_value_heads * self.head_dim,
|
| 221 |
+
bias=config.use_bias,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
self.v_proj = nn.Linear(
|
| 225 |
+
self.hidden_size,
|
| 226 |
+
self.num_key_value_heads * self.head_dim,
|
| 227 |
+
bias=config.use_bias,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
self.o_proj = nn.Linear(
|
| 231 |
+
self.hidden_size,
|
| 232 |
+
self.hidden_size,
|
| 233 |
+
bias=config.use_bias,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
self.config = config
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def forward(
|
| 240 |
+
self,
|
| 241 |
+
hidden_states: torch.Tensor,
|
| 242 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 243 |
+
past_key_value: Optional[Cache] = None,
|
| 244 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 245 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 246 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 247 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 248 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 249 |
+
B, L = hidden_states.shape[:-1]
|
| 250 |
+
|
| 251 |
+
query_states = self.q_proj(hidden_states).view(B, L, self.num_heads, -1).transpose(1, 2)
|
| 252 |
+
key_states = self.k_proj(hidden_states).view(B, L, self.num_key_value_heads, -1).transpose(1, 2)
|
| 253 |
+
value_states = self.v_proj(hidden_states).view(B, L, self.num_key_value_heads, -1).transpose(1, 2)
|
| 254 |
+
|
| 255 |
+
cos, sin = position_embeddings
|
| 256 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 257 |
+
|
| 258 |
+
if past_key_value is not None:
|
| 259 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 260 |
+
cache_kwargs = {"cache_position": cache_position}
|
| 261 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 262 |
+
|
| 263 |
+
attention_interface: Callable = eager_attention_forward
|
| 264 |
+
if self.config._attn_implementation != "eager":
|
| 265 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 266 |
+
|
| 267 |
+
attn_output, attn_weights = attention_interface(
|
| 268 |
+
self,
|
| 269 |
+
query_states,
|
| 270 |
+
key_states,
|
| 271 |
+
value_states,
|
| 272 |
+
attention_mask,
|
| 273 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 274 |
+
scaling=self.scaling,
|
| 275 |
+
**kwargs,
|
| 276 |
+
)
|
| 277 |
+
attn_output = attn_output.reshape(B, L, -1).contiguous()
|
| 278 |
+
attn_output = self.o_proj(attn_output)
|
| 279 |
+
|
| 280 |
+
return attn_output, attn_weights
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class Ernie4_5_MLP(nn.Module):
|
| 284 |
+
"""
|
| 285 |
+
Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model.
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
def __init__(self, config, intermediate_size=None):
|
| 289 |
+
"""
|
| 290 |
+
Initialize the MLP module with configuration options.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
config: Model configuration object with attributes:
|
| 294 |
+
- hidden_size: int
|
| 295 |
+
- intermediate_size: int
|
| 296 |
+
- use_bias: bool
|
| 297 |
+
intermediate_size (int, optional): Feed-forward width; defaults to config.intermediate_size
|
| 298 |
+
"""
|
| 299 |
+
super().__init__()
|
| 300 |
+
self.config = config
|
| 301 |
+
self.hidden_size = config.hidden_size
|
| 302 |
+
self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
|
| 303 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
|
| 304 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
|
| 305 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def forward(self, x):
|
| 309 |
+
"""
|
| 310 |
+
Args:
|
| 311 |
+
x (Tensor): shape [batch_size, seq_len, hidden_size]
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
Tensor: shape [batch_size, seq_len, hidden_size]
|
| 315 |
+
"""
|
| 316 |
+
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 317 |
+
return down_proj
|
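A minimal smoke test of the gated (SwiGLU-style) block above, using a hypothetical stub config carrying only the three attributes the constructor reads (hidden_size, intermediate_size, use_bias):

import types
import torch

cfg = types.SimpleNamespace(hidden_size=8, intermediate_size=16, use_bias=False)
mlp = Ernie4_5_MLP(cfg)
print(mlp(torch.randn(2, 3, 8)).shape)   # torch.Size([2, 3, 8])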
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class Ernie4_5_MoeStatics(nn.Module):
|
| 321 |
+
"""
|
| 322 |
+
Stores MoE (Mixture of Experts) statistics
|
| 323 |
+
and expert usage information.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
def __init__(self, config):
|
| 327 |
+
"""
|
| 328 |
+
Initialize MoE statistics tracking.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
config: Model configuration containing MoE parameters
|
| 332 |
+
"""
|
| 333 |
+
super().__init__()
|
| 334 |
+
|
| 335 |
+
num_experts = config.moe_num_experts
|
| 336 |
+
num_experts_groups = 1
|
| 337 |
+
|
| 338 |
+
self.e_score_correction_bias = nn.Parameter(
|
| 339 |
+
torch.zeros(num_experts_groups, num_experts, dtype=torch.float32),
|
| 340 |
+
requires_grad=False
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
class Ernie4_5_MoeMLP(nn.Module):
|
| 344 |
+
"""Mixture of Experts (MoE) variant of ERNIE's MLP layer."""
|
| 345 |
+
|
| 346 |
+
def __init__(self, config):
|
| 347 |
+
super().__init__()
|
| 348 |
+
self.config = config
|
| 349 |
+
self.k = config.moe_k
|
| 350 |
+
self.sinkhorn_2gate = config.sinkhorn_2gate
|
| 351 |
+
self.sinkhorn_temp = config.sinkhorn_temp
|
| 352 |
+
|
| 353 |
+
moe_intermediate_size = config.moe_intermediate_size if config.moe_intermediate_size else config.intermediate_size
|
| 354 |
+
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
|
| 355 |
+
if config.moe_gate_act == "softmax":
|
| 356 |
+
self.gate_act = partial(F.softmax, dim=-1)
|
| 357 |
+
elif config.moe_gate_act == "sigmoid":
|
| 358 |
+
self.gate_act = F.sigmoid
|
| 359 |
+
else:
|
| 360 |
+
raise ValueError(f"{config.moe_gate_act} is not supported.")
|
| 361 |
+
|
| 362 |
+
self.experts = nn.ModuleList(
|
| 363 |
+
[Ernie4_5_MLP(config, moe_intermediate_size) for _ in range(config.moe_num_experts)]
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
if config.moe_use_aux_free:
|
| 367 |
+
self.moe_statics = Ernie4_5_MoeStatics(config)
|
| 368 |
+
|
| 369 |
+
self.use_correction_bias = config.moe_use_aux_free
|
| 370 |
+
self.num_local_experts = len(self.experts)
|
| 371 |
+
|
| 372 |
+
self.shared_experts = self._init_shared_experts()
|
| 373 |
+
|
| 374 |
+
def _init_shared_experts(self):
|
| 375 |
+
"""
|
| 376 |
+
Initialize the shared expert module.
|
| 377 |
+
|
| 378 |
+
Returns:
|
| 379 |
+
shared_experts: Shared expert module, returns None if no shared experts are needed.
|
| 380 |
+
|
| 381 |
+
"""
|
| 382 |
+
cfg = deepcopy(self.config)
|
| 383 |
+
if getattr(cfg, 'moe_num_shared_experts', 0) > 0:
|
| 384 |
+
if getattr(cfg, 'moe_intermediate_size', None):
|
| 385 |
+
cfg.intermediate_size = cfg.moe_intermediate_size * cfg.moe_num_shared_experts
|
| 386 |
+
else:
|
| 387 |
+
cfg.intermediate_size = cfg.intermediate_size * cfg.moe_num_shared_experts
|
| 388 |
+
shared_experts = Ernie4_5_MLP(cfg, cfg.intermediate_size)
|
| 389 |
+
else:
|
| 390 |
+
shared_experts = None
|
| 391 |
+
return shared_experts
|
| 392 |
+
|
| 393 |
+
def forward(
|
| 394 |
+
self,
|
| 395 |
+
input: torch.Tensor,
|
| 396 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 397 |
+
"""
|
| 398 |
+
Forward pass through MoE layer.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
input (Tensor): Input tensor of shape [s, d].
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
tuple: (output, combine_weights, router_loss, gate_logits)
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
if input.dim() == 3:
|
| 409 |
+
orig_shape = input.shape
|
| 410 |
+
input = input.reshape(-1, input.shape[-1])
|
| 411 |
+
else:
|
| 412 |
+
orig_shape = None
|
| 413 |
+
assert input.dim() == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"
|
| 414 |
+
|
| 415 |
+
assert self.gate is not None
|
| 416 |
+
|
| 417 |
+
gate_input = input
|
| 418 |
+
|
| 419 |
+
(
|
| 420 |
+
dispatched_input,
|
| 421 |
+
combine_weights,
|
| 422 |
+
dispatch_mask,
|
| 423 |
+
scatter_index,
|
| 424 |
+
router_loss,
|
| 425 |
+
gate_logits,
|
| 426 |
+
gate_prob
|
| 427 |
+
) = self.gate_and_dispatch(gate_input)
|
| 428 |
+
|
| 429 |
+
expert_out = self.forward_experts(dispatched_input)
|
| 430 |
+
|
| 431 |
+
combined_output = self.combine_expert_output(expert_out, combine_weights, scatter_index)
|
| 432 |
+
|
| 433 |
+
if self.shared_experts is not None:
|
| 434 |
+
shared_expert_out = self.shared_experts(gate_input)
|
| 435 |
+
combined_output += shared_expert_out
|
| 436 |
+
|
| 437 |
+
if orig_shape:
|
| 438 |
+
combined_output = combined_output.reshape(orig_shape[:-1] + (combined_output.shape[-1],))
|
| 439 |
+
|
| 440 |
+
return combined_output, combine_weights, router_loss, gate_logits
|
| 441 |
+
|
| 442 |
+
def forward_experts(self, dispatched_input: torch.Tensor) -> torch.Tensor:
|
| 443 |
+
"""
|
| 444 |
+
Forward pass through experts sequentially.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
dispatched_input (Tensor): Input tensor of shape [num_experts, capacity, dim].
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
Tensor: Expert outputs of shape [num_experts, capacity, dim].
|
| 451 |
+
"""
|
| 452 |
+
true_experts = self.experts
|
| 453 |
+
dispatched_input = dispatched_input.reshape(
|
| 454 |
+
1, self.num_local_experts, -1, dispatched_input.shape[-1]
|
| 455 |
+
)
|
| 456 |
+
expert_outputs = []
|
| 457 |
+
if isinstance(self.experts, nn.ModuleList):
|
| 458 |
+
chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0)
|
| 459 |
+
assert len(chunks) == len(true_experts), f"{len(chunks)}, {len(true_experts)}"
|
| 460 |
+
for chunk, expert in zip(chunks, true_experts):
|
| 461 |
+
expert_outputs.append(expert(chunk))
|
| 462 |
+
else:
|
| 463 |
+
dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous()
|
| 464 |
+
orig_shape = dispatched_input.shape
|
| 465 |
+
chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1])
|
| 466 |
+
chunks = self.experts(chunks)
|
| 467 |
+
chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0)
|
| 468 |
+
expert_outputs.extend(chunks)
|
| 469 |
+
|
| 470 |
+
expert_output = torch.stack(expert_outputs, dim=1)
|
| 471 |
+
return expert_output
|
| 472 |
+
|
| 473 |
+
def moe_gate_dispatch(
|
| 474 |
+
self,
|
| 475 |
+
x: torch.Tensor,
|
| 476 |
+
gate_logits: torch.Tensor,
|
| 477 |
+
k: int,
|
| 478 |
+
capacity: Optional[int],
|
| 479 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
| 480 |
+
torch.Tensor, torch.Tensor]:
|
| 481 |
+
|
| 482 |
+
S, H = x.shape
|
| 483 |
+
E = gate_logits.shape[1]
|
| 484 |
+
device = x.device
|
| 485 |
+
topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
|
| 486 |
+
combine_weights = topk_prob
|
| 487 |
+
expert_id = topk_idx
|
| 488 |
+
y = x.new_zeros((E, capacity, H))
|
| 489 |
+
scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
|
| 490 |
+
|
| 491 |
+
# per-expert slot counters
|
| 492 |
+
slot_counter = torch.zeros(E, dtype=torch.int32, device=device)
|
| 493 |
+
|
| 494 |
+
for tok in range(S):
|
| 495 |
+
for route in range(k):
|
| 496 |
+
e = expert_id[tok, route].item()
|
| 497 |
+
slot = slot_counter[e].item()
|
| 498 |
+
if slot >= capacity:
|
| 499 |
+
combine_weights[tok, route] = 0.0
|
| 500 |
+
continue
|
| 501 |
+
|
| 502 |
+
# record mapping & dispatch activation
|
| 503 |
+
scatter_index[route, tok] = e * capacity + slot
|
| 504 |
+
y[e, slot] = x[tok]
|
| 505 |
+
slot_counter[e] += 1
|
| 506 |
+
|
| 507 |
+
expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64)
|
| 508 |
+
|
| 509 |
+
return y, combine_weights, scatter_index, expert_offset, expert_id
|
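Since this reference dispatch loop never touches `self`, it can be exercised standalone; a toy example with illustrative shapes only:

import torch

x = torch.randn(5, 4)                          # 5 tokens, hidden size 4
probs = torch.softmax(torch.randn(5, 3), -1)   # routing probabilities over 3 experts
y, w, idx, offsets, eid = Ernie4_5_MoeMLP.moe_gate_dispatch(None, x, probs, k=2, capacity=4)
print(y.shape)        # torch.Size([3, 4, 4]) -> [num_experts, capacity, hidden]
print(offsets)        # cumulative number of filled slots per expert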
| 510 |
+
|
| 511 |
+
def combine_expert_output(self, expert_output: torch.Tensor, combine_weights: torch.Tensor, scatter_index: torch.Tensor) -> torch.Tensor:
|
| 512 |
+
"""
|
| 513 |
+
Combine expert outputs using combination weights.
|
| 514 |
+
|
| 515 |
+
Args:
|
| 516 |
+
expert_output (Tensor): Expert outputs [num_experts, capacity, dim].
|
| 517 |
+
combine_weights (Tensor): Combination weights.
|
| 518 |
+
scatter_index (Tensor): Scatter indices.
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
Tensor: Combined output [seqlen, dim].
|
| 522 |
+
"""
|
| 523 |
+
expert_output = expert_output.reshape(-1, expert_output.shape[-1])
|
| 524 |
+
combined_output = self.combining(expert_output, combine_weights, scatter_index)
|
| 525 |
+
return combined_output
|
| 526 |
+
|
| 527 |
+
def combining(self, x, combine_weights, scatter_index):
|
| 528 |
+
"""
|
| 529 |
+
Combines and aggregates input matrix using combination weights.
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
x (Tensor): Input tensor of shape [num_experts * capacity, dim]
|
| 533 |
+
combine_weights (Tensor): Combination weights of shape [seq, 2]
|
| 534 |
+
scatter_index (Tensor): Scatter indices of shape [seq, 2]
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
Tensor: Combined output tensor of shape [seq, dim]
|
| 538 |
+
"""
|
| 539 |
+
dim = x.shape[-1]
|
| 540 |
+
|
| 541 |
+
scatter_index = scatter_index.reshape([-1])
|
| 542 |
+
num_k = combine_weights.shape[-1]
|
| 543 |
+
|
| 544 |
+
combine_weights = combine_weights.unsqueeze(1)
|
| 545 |
+
|
| 546 |
+
x = x[scatter_index].reshape([-1, num_k, dim])
|
| 547 |
+
|
| 548 |
+
return torch.matmul(combine_weights, x).squeeze(1)
|
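A worked example of the weighted combine, with values chosen so the result is easy to verify by hand (`combining` also never reads `self`, so it can be called unbound):

import torch

slots = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # 4 expert slots, dim 2
weights = torch.tensor([[0.7, 0.3]])                        # one token, routed to 2 slots
index = torch.tensor([[0, 2]])                              # its activations sit in slots 0 and 2
# 0.7 * slots[0] + 0.3 * slots[2] = 0.7 * [0, 1] + 0.3 * [4, 5] = [1.2, 2.2]
print(Ernie4_5_MoeMLP.combining(None, slots, weights, index))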
| 549 |
+
|
| 550 |
+
def gate_and_dispatch(self, input):
|
| 551 |
+
"""
|
| 552 |
+
Calculate gate and dispatch inputs.
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
input: Input tensor of shape [seq, dim]
|
| 556 |
+
|
| 557 |
+
Returns:
|
| 558 |
+
tuple: (dispatched_input, combine_weights, dispatch_mask,
|
| 559 |
+
scatter_index, router_loss, gate_logits, gate_prob)
|
| 560 |
+
"""
|
| 561 |
+
gate_logits, capacity, router_loss = topk_gate_func(self, input)
|
| 562 |
+
|
| 563 |
+
# turn gate logits into routing probabilities
|
| 564 |
+
prob = self.gate_act(gate_logits)
|
| 565 |
+
(
|
| 566 |
+
dispatched_input,
|
| 567 |
+
combine_weights_unnorm,
|
| 568 |
+
scatter_index,
|
| 569 |
+
dispatch_mask,
|
| 570 |
+
_,
|
| 571 |
+
) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity)
|
| 572 |
+
dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0)))
|
| 573 |
+
|
| 574 |
+
scatter_index = scatter_index.detach()
|
| 575 |
+
dispatch_mask = dispatch_mask.detach()
|
| 576 |
+
|
| 577 |
+
scatter_index = scatter_index.transpose(0, 1) # [k, s] -> [s, k]
|
| 578 |
+
combine_weights = combine_weights_unnorm / torch.clamp(
|
| 579 |
+
combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12
|
| 580 |
+
)
|
| 581 |
+
combine_weights = combine_weights.to(dtype=dispatched_input.dtype)
|
| 582 |
+
|
| 583 |
+
return dispatched_input, combine_weights, dispatch_mask, scatter_index, router_loss, gate_logits, prob
|
| 584 |
+
|
| 585 |
+
def get_capacity(self, num_tokens, cap_factor=None):
|
| 586 |
+
"""
|
| 587 |
+
Calculate capacity based on number of tokens.
|
| 588 |
+
|
| 589 |
+
Args:
|
| 590 |
+
num_tokens: Number of input tokens
|
| 591 |
+
cap_factor: Optional capacity factor override
|
| 592 |
+
|
| 593 |
+
Returns:
|
| 594 |
+
int: Calculated capacity
|
| 595 |
+
"""
|
| 596 |
+
num_experts = self.config.moe_num_experts
|
| 597 |
+
if cap_factor is not None:
|
| 598 |
+
cap = cap_factor
|
| 599 |
+
else:
|
| 600 |
+
if self.training:
|
| 601 |
+
cap = self.config.moe_capacity[0]
|
| 602 |
+
elif num_tokens < num_experts:
|
| 603 |
+
cap = self.config.moe_capacity[2]
|
| 604 |
+
else:
|
| 605 |
+
cap = self.config.moe_capacity[1]
|
| 606 |
+
|
| 607 |
+
capacity = int(cap * num_tokens // num_experts)
|
| 608 |
+
assert capacity > 0, f"requires capacity > 0, got cap={cap}, num_tokens={num_tokens}"
|
| 609 |
+
return capacity
|
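A worked example of the capacity formula under the defaults above (moe_capacity=(64, 64, 64), moe_num_experts=64):

cap, num_tokens, num_experts = 64, 1024, 64
print(int(cap * num_tokens // num_experts))   # 1024 slots per expert during training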
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class Ernie4_5_RMSNorm(nn.Module):
|
| 613 |
+
"""
|
| 614 |
+
Ernie Root Mean Square Layer Normalization (Ernie4_5_RMSNorm) implementation.
|
| 615 |
+
|
| 616 |
+
Ernie4_5_RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
|
| 617 |
+
omitting the mean-centering operation. This provides computational efficiency while maintaining
|
| 618 |
+
good performance.
|
| 619 |
+
|
| 620 |
+
"""
|
| 621 |
+
|
| 622 |
+
def __init__(self, config):
|
| 623 |
+
"""
|
| 624 |
+
Initialize RMSNorm layer.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
config (ErnieConfig): Model configuration.
|
| 628 |
+
"""
|
| 629 |
+
super().__init__()
|
| 630 |
+
self.config = config
|
| 631 |
+
self.hidden_size = config.hidden_size
|
| 632 |
+
self.weight = nn.Parameter(torch.ones(config.hidden_size))
|
| 633 |
+
self.variance_epsilon = config.rms_norm_eps
|
| 634 |
+
|
| 635 |
+
def forward(self, hidden_states):
|
| 636 |
+
"""
|
| 637 |
+
Apply RMS normalization to input hidden states.
|
| 638 |
+
|
| 639 |
+
Args:
|
| 640 |
+
hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
|
| 641 |
+
|
| 642 |
+
Returns:
|
| 643 |
+
Tensor: Normalized output tensor of same shape as input
|
| 644 |
+
"""
|
| 645 |
+
input_dtype = hidden_states.dtype
|
| 646 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 647 |
+
variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
|
| 648 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 649 |
+
|
| 650 |
+
return self.weight * hidden_states.to(input_dtype)
|
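A quick numerical check against the textbook formula x / sqrt(mean(x^2) + eps) * weight, using a hypothetical stub config (weight is all ones at init, so it drops out):

import types
import torch

cfg = types.SimpleNamespace(hidden_size=8, rms_norm_eps=1e-6)
norm = Ernie4_5_RMSNorm(cfg)
x = torch.randn(2, 3, 8)
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), ref, atol=1e-5)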
| 651 |
+
|
| 652 |
+
|
| 653 |
+
class Ernie4_5_RopeEmbedding(nn.Module):
|
| 654 |
+
def __init__(self, config: Ernie4_5_MoeConfig, device=None):
|
| 655 |
+
super().__init__()
|
| 656 |
+
# BC: "rope_type" was originally "type"
|
| 657 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 658 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 659 |
+
else:
|
| 660 |
+
self.rope_type = "default"
|
| 661 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 662 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 663 |
+
|
| 664 |
+
self.config = config
|
| 665 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 666 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 667 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 668 |
+
self.original_inv_freq = self.inv_freq
|
| 669 |
+
|
| 670 |
+
@torch.no_grad()
|
| 671 |
+
def forward(self, x, position_ids):
|
| 672 |
+
inv_freq_expanded = self.inv_freq[None, None, :].float()
|
| 673 |
+
position_ids_expanded = position_ids[..., None].float()
|
| 674 |
+
freqs = inv_freq_expanded * position_ids_expanded
|
| 675 |
+
cos = torch.cos(freqs) * self.attention_scaling
|
| 676 |
+
sin = torch.sin(freqs) * self.attention_scaling
|
| 677 |
+
return cos, sin
|
| 678 |
+
# return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class Ernie4_5_DecoderLayer(nn.Module):
|
| 682 |
+
"""A single transformer decoder layer in ERNIE-MoE model.
|
| 683 |
+
|
| 684 |
+
Contains self-attention and feed-forward components with optional MoE (Mixture of Experts)
|
| 685 |
+
support, residual connections, and layer normalization.
|
| 686 |
+
"""
|
| 687 |
+
|
| 688 |
+
def __init__(self, config, layer_idx):
|
| 689 |
+
"""Initialize the decoder layer.
|
| 690 |
+
|
| 691 |
+
Args:
|
| 692 |
+
config (ErnieMoEConfig): Model configuration.
|
| 693 |
+
layer_idx (int): Index of this layer in the transformer stack
|
| 694 |
+
"""
|
| 695 |
+
super().__init__()
|
| 696 |
+
self.hidden_size = config.hidden_size
|
| 697 |
+
self.layer_idx = layer_idx
|
| 698 |
+
self.config = config
|
| 699 |
+
self.use_moe = config.use_moe
|
| 700 |
+
self.self_attn = Ernie4_5_Attention(config, layer_idx)
|
| 701 |
+
|
| 702 |
+
moe_layer_start_index = (
|
| 703 |
+
min(config.moe_layer_start_index)
|
| 704 |
+
if isinstance(config.moe_layer_start_index, (tuple, list))
|
| 705 |
+
else config.moe_layer_start_index
|
| 706 |
+
)
|
| 707 |
+
moe_layer_end_index = (
|
| 708 |
+
max(config.moe_layer_end_index)
|
| 709 |
+
if isinstance(config.moe_layer_end_index, (tuple, list))
|
| 710 |
+
else config.moe_layer_end_index
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
if (
|
| 714 |
+
self.use_moe
|
| 715 |
+
and ((layer_idx + 1) % config.moe_layer_interval == 0)
|
| 716 |
+
and layer_idx >= moe_layer_start_index
|
| 717 |
+
and layer_idx <= moe_layer_end_index
|
| 718 |
+
):
|
| 719 |
+
self.mlp = Ernie4_5_MoeMLP(config)
|
| 720 |
+
else:
|
| 721 |
+
self.mlp = Ernie4_5_MLP(config)
|
| 722 |
+
|
| 723 |
+
self.input_layernorm = Ernie4_5_RMSNorm(config)
|
| 724 |
+
self.post_attention_layernorm = Ernie4_5_RMSNorm(config)
|
| 725 |
+
|
| 726 |
+
self.residual_add1 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
|
| 727 |
+
self.residual_add2 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
|
| 728 |
+
|
| 729 |
+
def forward(
|
| 730 |
+
self,
|
| 731 |
+
hidden_states: torch.Tensor,
|
| 732 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 733 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 734 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 735 |
+
output_attentions: Optional[bool] = False,
|
| 736 |
+
use_cache: Optional[bool] = False,
|
| 737 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 738 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 739 |
+
output_router_loss: bool = True,
|
| 740 |
+
output_gate_logits: bool = True,
|
| 741 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 742 |
+
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 743 |
+
"""Forward pass through the decoder layer.
|
| 744 |
+
|
| 745 |
+
Args:
|
| 746 |
+
hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
|
| 747 |
+
attention_mask (Optional[torch.Tensor]): Attention mask tensor
|
| 748 |
+
position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
|
| 749 |
+
past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
|
| 750 |
+
output_attentions (Optional[bool]): Whether to return attention weights
|
| 751 |
+
use_cache (Optional[bool]): Whether to cache key/value states
|
| 752 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 753 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 754 |
+
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 755 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 756 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 757 |
+
output_router_loss (bool): Whether to return MoE router loss
|
| 758 |
+
output_gate_logits (bool): Whether to return MoE gate logits
|
| 759 |
+
|
| 760 |
+
Returns:
|
| 761 |
+
Union: Various output combinations depending on arguments:
|
| 762 |
+
- Base case: Hidden states tensor
|
| 763 |
+
- With attention: Tuple of (hidden_states, attention_weights)
|
| 764 |
+
- With router loss: includes the MoE router loss in the output tuple
|
| 765 |
+
- With MoE gate logits: May include gate logits in output tuple
|
| 766 |
+
"""
|
| 767 |
+
residual = hidden_states
|
| 768 |
+
|
| 769 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 770 |
+
|
| 771 |
+
# Self Attention
|
| 772 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 773 |
+
hidden_states=hidden_states,
|
| 774 |
+
attention_mask=attention_mask,
|
| 775 |
+
past_key_value=past_key_value,
|
| 776 |
+
position_ids=position_ids,
|
| 777 |
+
use_cache=use_cache,
|
| 778 |
+
cache_position=cache_position,
|
| 779 |
+
position_embeddings=position_embeddings,
|
| 780 |
+
**kwargs,
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
hidden_states = self.residual_add1(hidden_states, residual)
|
| 784 |
+
|
| 785 |
+
# Fully Connected
|
| 786 |
+
residual = hidden_states
|
| 787 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 788 |
+
|
| 789 |
+
router_loss = None
|
| 790 |
+
gate_logits = None
|
| 791 |
+
|
| 792 |
+
if isinstance(self.mlp, Ernie4_5_MoeMLP):
|
| 793 |
+
hidden_states, _, router_loss, gate_logits = self.mlp(hidden_states)
|
| 794 |
+
else:
|
| 795 |
+
hidden_states = self.mlp(hidden_states)
|
| 796 |
+
|
| 797 |
+
hidden_states = self.residual_add2(hidden_states, residual)
|
| 798 |
+
|
| 799 |
+
outputs = (hidden_states,)
|
| 800 |
+
|
| 801 |
+
if output_attentions:
|
| 802 |
+
outputs += (self_attn_weights,)
|
| 803 |
+
|
| 804 |
+
if output_router_loss:
|
| 805 |
+
outputs += (router_loss,)
|
| 806 |
+
|
| 807 |
+
if output_gate_logits:
|
| 808 |
+
outputs += (gate_logits,)
|
| 809 |
+
|
| 810 |
+
return outputs
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
@auto_docstring
|
| 814 |
+
class Ernie4_5_PretrainedModel(PreTrainedModel):
|
| 815 |
+
"""Base class for ERNIE pretrained models."""
|
| 816 |
+
config_class = Ernie4_5_MoeConfig
|
| 817 |
+
base_model_prefix = "model"
|
| 818 |
+
supports_gradient_checkpointing = True
|
| 819 |
+
_no_split_modules = ["Ernie4_5_DecoderLayer"]
|
| 820 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 821 |
+
_supports_flash_attn_2 = True
|
| 822 |
+
_supports_sdpa = True
|
| 823 |
+
_supports_flex_attn = True
|
| 824 |
+
_supports_cache_class = True
|
| 825 |
+
_supports_quantized_cache = True
|
| 826 |
+
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def subbatch(f, arg_idx, axis, bs, out_idx, same_arg_idx={}):
|
| 830 |
+
"""
|
| 831 |
+
Converts a function to one that applies to subbatch of an input dimension.
|
| 832 |
+
Useful for processing large tensors in smaller chunks to reduce memory usage.
|
| 833 |
+
|
| 834 |
+
Args:
|
| 835 |
+
f (Callable): Function to be subbatched.
|
| 836 |
+
arg_idx ([int]): Indices of the inputs to be subbatched.
|
| 837 |
+
axis ([int]): Indices of the dimensions to be subbatched for each input.
|
| 838 |
+
bs (int): Subbatch size.
|
| 839 |
+
out_idx (int): Dimension to concatenate outputs along.
|
| 840 |
+
same_arg_idx (dict): Mapping of argument indices that share the same tensor.
|
| 841 |
+
|
| 842 |
+
Returns:
|
| 843 |
+
Callable: New function that processes inputs in subbatches.
|
| 844 |
+
"""
|
| 845 |
+
|
| 846 |
+
@functools.wraps(f)
|
| 847 |
+
def wrapper(*args, **kwargs):
|
| 848 |
+
|
| 849 |
+
assert len(arg_idx) == len(axis), "Number of batching args and number of batching dims should match."
|
| 850 |
+
|
| 851 |
+
inps = [args[i] for i in arg_idx]
|
| 852 |
+
axis_width = [inp.shape[d] for inp, d in zip(inps, axis)]
|
| 853 |
+
assert len(set(axis_width)) == 1, "Batch sizes should be kept equal."
|
| 854 |
+
|
| 855 |
+
inp_axis = {idx: d for idx, d in zip(arg_idx, axis)}
|
| 856 |
+
|
| 857 |
+
axis_width = axis_width[0]
|
| 858 |
+
if axis_width < bs:
|
| 859 |
+
return f(*args, **kwargs)
|
| 860 |
+
|
| 861 |
+
outs = []
|
| 862 |
+
for slice_at in range(0, axis_width, bs):
|
| 863 |
+
_args = []
|
| 864 |
+
for i, inp in enumerate(args):
|
| 865 |
+
if i in same_arg_idx:
|
| 866 |
+
assert (
|
| 867 |
+
i > same_arg_idx[i]
|
| 868 |
+
), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}"
|
| 869 |
+
_args.append(_args[same_arg_idx[i]])
|
| 870 |
+
elif i in arg_idx:
|
| 871 |
+
d = inp_axis[i]
|
| 872 |
+
start = slice_at
|
| 873 |
+
end = min(inp.shape[d], slice_at + bs)
|
| 874 |
+
# Build slice for all dims, only slice along axis d
|
| 875 |
+
slices = [slice(None)] * inp.ndim
|
| 876 |
+
slices[d] = slice(start, end)
|
| 877 |
+
_args.append(inp[tuple(slices)])
|
| 878 |
+
else:
|
| 879 |
+
_args.append(inp)
|
| 880 |
+
|
| 881 |
+
out = f(*_args, **kwargs)
|
| 882 |
+
outs.append(out)
|
| 883 |
+
|
| 884 |
+
return torch.cat(outs, dim=out_idx)
|
| 885 |
+
|
| 886 |
+
return wrapper
|
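Usage sketch: wrapping a function so it runs over row chunks of its first argument; any elementwise f gives identical results chunked or not.

import torch

def double(t):
    return t * 2

chunked = subbatch(double, arg_idx=[0], axis=[0], bs=4, out_idx=0)
t = torch.arange(10.).reshape(10, 1)
assert torch.equal(chunked(t), double(t))   # 3 chunks of <= 4 rows, concatenated on dim 0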
| 887 |
+
|
| 888 |
+
|
| 889 |
+
class ErniePretrainingCriterion(nn.Module):
|
| 890 |
+
"""Criterion for ERNIE pretraining task."""
|
| 891 |
+
|
| 892 |
+
def __init__(self, config, return_tuple=True):
|
| 893 |
+
"""Initialize the pretraining criterion.
|
| 894 |
+
|
| 895 |
+
Args:
|
| 896 |
+
config (ErnieConfig): Model configuration.
|
| 897 |
+
return_tuple (bool): Whether to return loss as tuple (loss, loss_sum). Defaults to True.
|
| 898 |
+
"""
|
| 899 |
+
super().__init__()
|
| 900 |
+
self.ignored_index = getattr(config, "ignored_index", -100)
|
| 901 |
+
self.config = config
|
| 902 |
+
self.return_tuple = return_tuple
|
| 903 |
+
|
| 904 |
+
self.loss_func = nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignored_index)
|
| 905 |
+
|
| 906 |
+
def forward(self, prediction_scores, masked_lm_labels, loss_mask, router_loss=None):
|
| 907 |
+
"""Compute the combined pretraining loss.
|
| 908 |
+
|
| 909 |
+
Args:
|
| 910 |
+
prediction_scores: Prediction scores tensor, [batch_size, seq_len, vocab_size]
|
| 911 |
+
masked_lm_labels: Target labels tensor [batch_size, seq_len]
|
| 912 |
+
loss_mask: Optional mask for valid tokens
|
| 913 |
+
router_loss: Optional MoE router loss tensor
|
| 914 |
+
|
| 915 |
+
Returns:
|
| 916 |
+
Union:
|
| 917 |
+
- If return_tuple=True: Tuple of (combined_loss, mlm_loss_sum)
|
| 918 |
+
- If return_tuple=False: Combined loss tensor
|
| 919 |
+
"""
|
| 920 |
+
res = self.forward_impl(prediction_scores, masked_lm_labels, loss_mask)
|
| 921 |
+
|
| 922 |
+
if self.return_tuple:
|
| 923 |
+
loss, loss_sum = res
|
| 924 |
+
else:
|
| 925 |
+
loss, loss_sum = res, None
|
| 926 |
+
|
| 927 |
+
if router_loss is not None and isinstance(router_loss, torch.Tensor):
|
| 928 |
+
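# straight-through trick: keeps router_loss in the autograd graph without changing the reported loss value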
loss = loss + router_loss - router_loss.detach()
|
| 929 |
+
|
| 930 |
+
return loss, loss_sum
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def loss_impl(self, prediction_scores: torch.Tensor, masked_lm_labels: torch.Tensor) -> torch.Tensor:
|
| 934 |
+
"""
|
| 935 |
+
Core loss computation without reduction (but per-token).
|
| 936 |
+
|
| 937 |
+
Args:
|
| 938 |
+
prediction_scores (torch.Tensor): Logits tensor [batch_size, seq_len, vocab_size].
|
| 939 |
+
masked_lm_labels (torch.Tensor): Target labels tensor [batch_size, seq_len].
|
| 940 |
+
|
| 941 |
+
Returns:
|
| 942 |
+
torch.Tensor: Unreduced loss tensor of shape [batch_size, seq_len].
|
| 943 |
+
Losses are calculated in float32.
|
| 944 |
+
"""
|
| 945 |
+
scores_float32 = prediction_scores.to(torch.float32)
|
| 946 |
+
# prediction_scores: [batch_size, seq_len, vocab_size]
|
| 947 |
+
# masked_lm_labels: [batch_size, seq_len]
|
| 948 |
+
# Transpose prediction_scores to [batch_size, vocab_size, seq_len]
|
| 949 |
+
unreduced_loss = self.loss_func(
|
| 950 |
+
scores_float32.transpose(1, 2), # Shape: [batch_size, vocab_size, seq_len]
|
| 951 |
+
masked_lm_labels.long() # Shape: [batch_size, seq_len], ensure long type
|
| 952 |
+
)
|
| 953 |
+
# unreduced_loss will be of shape [batch_size, seq_len] and dtype float32
|
| 954 |
+
return unreduced_loss
|
| 955 |
+
|
| 956 |
+
def forward_impl(self, prediction_scores, masked_lm_labels, loss_mask=None):
|
| 957 |
+
prediction_scores_dims = len(prediction_scores.shape)
|
| 958 |
+
|
| 959 |
+
loss_subbatch_seqlen_config_key = "loss_subbatch_seqlen"
|
| 960 |
+
default_loss_subbatch_seqlen = 32768
|
| 961 |
+
|
| 962 |
+
current_loss_subbatch_seqlen = self.config.get(
|
| 963 |
+
loss_subbatch_seqlen_config_key, default_loss_subbatch_seqlen
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
if prediction_scores_dims == 2 and prediction_scores.shape[0] > current_loss_subbatch_seqlen:
|
| 967 |
+
sb_loss_func = subbatch(
|
| 968 |
+
self.loss_impl, [0, 1], [0, 0], current_loss_subbatch_seqlen, 0
|
| 969 |
+
)
|
| 970 |
+
masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
|
| 971 |
+
elif prediction_scores_dims == 3 and prediction_scores.shape[1] > current_loss_subbatch_seqlen:
|
| 972 |
+
sb_loss_func = subbatch(
|
| 973 |
+
self.loss_impl, [0, 1], [1, 1], current_loss_subbatch_seqlen, 1
|
| 974 |
+
)
|
| 975 |
+
masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
|
| 976 |
+
else:
|
| 977 |
+
masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)
|
| 978 |
+
|
| 979 |
+
if loss_mask is None:
|
| 980 |
+
loss_mask = masked_lm_labels != self.ignored_index
|
| 981 |
+
|
| 982 |
+
loss_mask = loss_mask.reshape(-1).to(torch.float32)
|
| 983 |
+
|
| 984 |
+
masked_lm_loss = torch.sum(masked_lm_loss.to(torch.float32).reshape(-1) * loss_mask)
|
| 985 |
+
|
| 986 |
+
# The division will be in float32
|
| 987 |
+
loss = masked_lm_loss / loss_mask.sum()
|
| 988 |
+
|
| 989 |
+
loss_sum = masked_lm_loss.sum().detach()
|
| 990 |
+
|
| 991 |
+
if not self.return_tuple:
|
| 992 |
+
if self.training:
|
| 993 |
+
return loss
|
| 994 |
+
return loss_sum
|
| 995 |
+
return loss, loss_sum
|
| 996 |
+
|
| 997 |
+
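# Editor's note (not part of the upstream file): the expression
# `loss + router_loss - router_loss.detach()` in `forward` above leaves the
# numerical value of the loss unchanged (the two router terms cancel) while
# still routing gradients through `router_loss` into the MoE gate parameters.
# This is the usual trick for attaching an auxiliary loss to the backward pass
# without perturbing the reported language-modeling loss.
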
@auto_docstring
class Ernie4_5_Model(Ernie4_5_PretrainedModel):
    """The core ERNIE transformer model with MoE (Mixture of Experts) support."""

    _keep_in_fp32_modules = ["gate"]

    def __init__(self, config: Ernie4_5_MoeConfig):
        """Initialize the ERNIE model architecture."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.config = config

        self.embed_tokens = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
        )

        self.layers = nn.ModuleList(
            [
                Ernie4_5_DecoderLayer(config, i)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = Ernie4_5_RMSNorm(config)
        self.rotary_emb = Ernie4_5_RopeEmbedding(config=config)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Get the input embedding layer."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Set new input embeddings."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ):
        """Forward pass through the ERNIE model."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_loss = torch.tensor(0.0, device=inputs_embeds.device) if self.config.use_moe else None
        all_gate_logits = ()

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if self.config.use_moe:
                layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1]
                all_gate_logits = all_gate_logits + (gate_logits,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return Erine4_5_MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_loss=all_router_loss,
            gate_logits=all_gate_logits,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right';"
                        " this may lead to unexpected behaviour with the Flash Attention version of Ernie4_5."
                        " Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch to Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output_attentions is True, the sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by the F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Ernie4_5_MoeConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or does nothing if the input `attention_mask` is already 4D.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with a static cache, the mask should be as long as the static
                cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`Ernie4_5_MoeConfig`):
                The model's configuration class.
            past_key_values (`Cache`):
                The cache class currently being used to generate.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                -1, 1
            )
            text_config = config.get_text_config()
            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                # If we have a sliding window, we should not attend to tokens beyond the sliding-window length,
                # so we mask them out as well. The check verifies whether the current checkpoint was trained
                # with a sliding window or not.
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                        cache_position.reshape(-1, 1) - text_config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask

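# Editor's sketch (not part of the upstream file): for a 2D attention_mask of
# [[1, 1, 1, 0]] with sequence_length = target_length = 4, the helper above
# returns a [1, 1, 4, 4] tensor whose row i holds 0.0 at columns j <= i that
# are not padding (column 3 is masked) and torch.finfo(dtype).min everywhere
# else, i.e. an additive causal-plus-padding mask.
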
@auto_docstring
class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin):
    """ERNIE Mixture of Experts (MoE) model for causal language modeling."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        """
        Initializes the ERNIE MoE model for causal language modeling.

        Args:
            config (Ernie4_5_MoeConfig): Model configuration.
        """
        super().__init__(config)
        self.config = config
        self.model = Ernie4_5_Model(config)
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=config.weight_share_add_bias and config.use_bias,
        )
        self.loss_function = ErniePretrainingCriterion(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Returns the input embeddings layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Sets the input embeddings layer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Returns the output embeddings (LM head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Sets the output embeddings layer."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Sets the ERNIE decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Gets the transformer decoder."""
        return self.model

    @can_return_tuple
    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds=None,
        labels=None,
        loss_mask=None,
        use_cache=False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ):
        """
        Forward pass for causal language modeling.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss, router_loss = None, None
        if getattr(self.config, "use_moe", False):
            router_loss = outputs.router_loss

        if labels is not None:
            loss, _ = self.loss_function(logits, labels, loss_mask, router_loss)

        return Ernie4_5_MoeCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_loss=router_loss,
        )


__all__ = [
    "Ernie4_5_Model",
    "Ernie4_5_MoeForCausalLM",
    "Ernie4_5_PretrainedModel",
]

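For reference, a minimal sketch of loading this checkpoint with the custom classes above (the repository id is a placeholder; `trust_remote_code=True` is needed because the model and tokenizer code ship with the checkpoint rather than with `transformers`):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "<this-repository-id>"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype="auto")

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))
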
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<unk>", "unk_token": "<unk>", "cls_token": "<|begin_of_sentence|>", "sep_token": "<|end_of_sentence|>", "mask_token": "<mask:1>", "sys_start_token": "<mask:4>", "sys_end_token": "<mask:5>", "header_start_token": "<mask:6>", "header_end_token": "<mask:7>", "additional_special_tokens": ["<|IMAGE_PLACEHOLDER|>", "<|AUDIO_PLACEHOLDER|>", "<|LOC_0|>", "<|LOC_1|>", "<|LOC_2|>", "<|LOC_3|>", "<|LOC_4|>", "<|LOC_5|>", "<|LOC_6|>", "<|LOC_7|>", "<|LOC_8|>", "<|LOC_9|>", "<|LOC_10|>", "<|LOC_11|>", "<|LOC_12|>", "<|LOC_13|>", "<|LOC_14|>", "<|LOC_15|>", "<|LOC_16|>", "<|LOC_17|>", "<|LOC_18|>", "<|LOC_19|>", "<|LOC_20|>", "<|LOC_21|>", "<|LOC_22|>", "<|LOC_23|>", "<|LOC_24|>", "<|LOC_25|>", "<|LOC_26|>", "<|LOC_27|>", "<|LOC_28|>", "<|LOC_29|>", "<|LOC_30|>", "<|LOC_31|>", "<|LOC_32|>", "<|LOC_33|>", "<|LOC_34|>", "<|LOC_35|>", "<|LOC_36|>", "<|LOC_37|>", "<|LOC_38|>", "<|LOC_39|>", "<|LOC_40|>", "<|LOC_41|>", "<|LOC_42|>", "<|LOC_43|>", "<|LOC_44|>", "<|LOC_45|>", "<|LOC_46|>", "<|LOC_47|>", "<|LOC_48|>", "<|LOC_49|>", "<|LOC_50|>", "<|LOC_51|>", "<|LOC_52|>", "<|LOC_53|>", "<|LOC_54|>", "<|LOC_55|>", "<|LOC_56|>", "<|LOC_57|>", "<|LOC_58|>", "<|LOC_59|>", "<|LOC_60|>", "<|LOC_61|>", "<|LOC_62|>", "<|LOC_63|>", "<|LOC_64|>", "<|LOC_65|>", "<|LOC_66|>", "<|LOC_67|>", "<|LOC_68|>", "<|LOC_69|>", "<|LOC_70|>", "<|LOC_71|>", "<|LOC_72|>", "<|LOC_73|>", "<|LOC_74|>", "<|LOC_75|>", "<|LOC_76|>", "<|LOC_77|>", "<|LOC_78|>", "<|LOC_79|>", "<|LOC_80|>", "<|LOC_81|>", "<|LOC_82|>", "<|LOC_83|>", "<|LOC_84|>", "<|LOC_85|>", "<|LOC_86|>", "<|LOC_87|>", "<|LOC_88|>", "<|LOC_89|>", "<|LOC_90|>", "<|LOC_91|>", "<|LOC_92|>", "<|LOC_93|>", "<|LOC_94|>", "<|LOC_95|>", "<|LOC_96|>", "<|LOC_97|>", "<|LOC_98|>", "<|LOC_99|>", "<|LOC_100|>", "<|LOC_101|>", "<|LOC_102|>", "<|LOC_103|>", "<|LOC_104|>", "<|LOC_105|>", "<|LOC_106|>", "<|LOC_107|>", "<|LOC_108|>", "<|LOC_109|>", "<|LOC_110|>", "<|LOC_111|>", "<|LOC_112|>", "<|LOC_113|>", "<|LOC_114|>", "<|LOC_115|>", "<|LOC_116|>", "<|LOC_117|>", "<|LOC_118|>", "<|LOC_119|>", "<|LOC_120|>", "<|LOC_121|>", "<|LOC_122|>", "<|LOC_123|>", "<|LOC_124|>", "<|LOC_125|>", "<|LOC_126|>", "<|LOC_127|>", "<|LOC_128|>", "<|LOC_129|>", "<|LOC_130|>", "<|LOC_131|>", "<|LOC_132|>", "<|LOC_133|>", "<|LOC_134|>", "<|LOC_135|>", "<|LOC_136|>", "<|LOC_137|>", "<|LOC_138|>", "<|LOC_139|>", "<|LOC_140|>", "<|LOC_141|>", "<|LOC_142|>", "<|LOC_143|>", "<|LOC_144|>", "<|LOC_145|>", "<|LOC_146|>", "<|LOC_147|>", "<|LOC_148|>", "<|LOC_149|>", "<|LOC_150|>", "<|LOC_151|>", "<|LOC_152|>", "<|LOC_153|>", "<|LOC_154|>", "<|LOC_155|>", "<|LOC_156|>", "<|LOC_157|>", "<|LOC_158|>", "<|LOC_159|>", "<|LOC_160|>", "<|LOC_161|>", "<|LOC_162|>", "<|LOC_163|>", "<|LOC_164|>", "<|LOC_165|>", "<|LOC_166|>", "<|LOC_167|>", "<|LOC_168|>", "<|LOC_169|>", "<|LOC_170|>", "<|LOC_171|>", "<|LOC_172|>", "<|LOC_173|>", "<|LOC_174|>", "<|LOC_175|>", "<|LOC_176|>", "<|LOC_177|>", "<|LOC_178|>", "<|LOC_179|>", "<|LOC_180|>", "<|LOC_181|>", "<|LOC_182|>", "<|LOC_183|>", "<|LOC_184|>", "<|LOC_185|>", "<|LOC_186|>", "<|LOC_187|>", "<|LOC_188|>", "<|LOC_189|>", "<|LOC_190|>", "<|LOC_191|>", "<|LOC_192|>", "<|LOC_193|>", "<|LOC_194|>", "<|LOC_195|>", "<|LOC_196|>", "<|LOC_197|>", "<|LOC_198|>", "<|LOC_199|>", "<|LOC_200|>", "<|LOC_201|>", "<|LOC_202|>", "<|LOC_203|>", "<|LOC_204|>", "<|LOC_205|>", "<|LOC_206|>", "<|LOC_207|>", "<|LOC_208|>", "<|LOC_209|>", "<|LOC_210|>", "<|LOC_211|>", "<|LOC_212|>", "<|LOC_213|>", "<|LOC_214|>", "<|LOC_215|>", "<|LOC_216|>", "<|LOC_217|>", 
"<|LOC_218|>", "<|LOC_219|>", "<|LOC_220|>", "<|LOC_221|>", "<|LOC_222|>", "<|LOC_223|>", "<|LOC_224|>", "<|LOC_225|>", "<|LOC_226|>", "<|LOC_227|>", "<|LOC_228|>", "<|LOC_229|>", "<|LOC_230|>", "<|LOC_231|>", "<|LOC_232|>", "<|LOC_233|>", "<|LOC_234|>", "<|LOC_235|>", "<|LOC_236|>", "<|LOC_237|>", "<|LOC_238|>", "<|LOC_239|>", "<|LOC_240|>", "<|LOC_241|>", "<|LOC_242|>", "<|LOC_243|>", "<|LOC_244|>", "<|LOC_245|>", "<|LOC_246|>", "<|LOC_247|>", "<|LOC_248|>", "<|LOC_249|>", "<|LOC_250|>", "<|LOC_251|>", "<|LOC_252|>", "<|LOC_253|>", "<|LOC_254|>", "<|LOC_255|>", "<|LOC_256|>", "<|LOC_257|>", "<|LOC_258|>", "<|LOC_259|>", "<|LOC_260|>", "<|LOC_261|>", "<|LOC_262|>", "<|LOC_263|>", "<|LOC_264|>", "<|LOC_265|>", "<|LOC_266|>", "<|LOC_267|>", "<|LOC_268|>", "<|LOC_269|>", "<|LOC_270|>", "<|LOC_271|>", "<|LOC_272|>", "<|LOC_273|>", "<|LOC_274|>", "<|LOC_275|>", "<|LOC_276|>", "<|LOC_277|>", "<|LOC_278|>", "<|LOC_279|>", "<|LOC_280|>", "<|LOC_281|>", "<|LOC_282|>", "<|LOC_283|>", "<|LOC_284|>", "<|LOC_285|>", "<|LOC_286|>", "<|LOC_287|>", "<|LOC_288|>", "<|LOC_289|>", "<|LOC_290|>", "<|LOC_291|>", "<|LOC_292|>", "<|LOC_293|>", "<|LOC_294|>", "<|LOC_295|>", "<|LOC_296|>", "<|LOC_297|>", "<|LOC_298|>", "<|LOC_299|>", "<|LOC_300|>", "<|LOC_301|>", "<|LOC_302|>", "<|LOC_303|>", "<|LOC_304|>", "<|LOC_305|>", "<|LOC_306|>", "<|LOC_307|>", "<|LOC_308|>", "<|LOC_309|>", "<|LOC_310|>", "<|LOC_311|>", "<|LOC_312|>", "<|LOC_313|>", "<|LOC_314|>", "<|LOC_315|>", "<|LOC_316|>", "<|LOC_317|>", "<|LOC_318|>", "<|LOC_319|>", "<|LOC_320|>", "<|LOC_321|>", "<|LOC_322|>", "<|LOC_323|>", "<|LOC_324|>", "<|LOC_325|>", "<|LOC_326|>", "<|LOC_327|>", "<|LOC_328|>", "<|LOC_329|>", "<|LOC_330|>", "<|LOC_331|>", "<|LOC_332|>", "<|LOC_333|>", "<|LOC_334|>", "<|LOC_335|>", "<|LOC_336|>", "<|LOC_337|>", "<|LOC_338|>", "<|LOC_339|>", "<|LOC_340|>", "<|LOC_341|>", "<|LOC_342|>", "<|LOC_343|>", "<|LOC_344|>", "<|LOC_345|>", "<|LOC_346|>", "<|LOC_347|>", "<|LOC_348|>", "<|LOC_349|>", "<|LOC_350|>", "<|LOC_351|>", "<|LOC_352|>", "<|LOC_353|>", "<|LOC_354|>", "<|LOC_355|>", "<|LOC_356|>", "<|LOC_357|>", "<|LOC_358|>", "<|LOC_359|>", "<|LOC_360|>", "<|LOC_361|>", "<|LOC_362|>", "<|LOC_363|>", "<|LOC_364|>", "<|LOC_365|>", "<|LOC_366|>", "<|LOC_367|>", "<|LOC_368|>", "<|LOC_369|>", "<|LOC_370|>", "<|LOC_371|>", "<|LOC_372|>", "<|LOC_373|>", "<|LOC_374|>", "<|LOC_375|>", "<|LOC_376|>", "<|LOC_377|>", "<|LOC_378|>", "<|LOC_379|>", "<|LOC_380|>", "<|LOC_381|>", "<|LOC_382|>", "<|LOC_383|>", "<|LOC_384|>", "<|LOC_385|>", "<|LOC_386|>", "<|LOC_387|>", "<|LOC_388|>", "<|LOC_389|>", "<|LOC_390|>", "<|LOC_391|>", "<|LOC_392|>", "<|LOC_393|>", "<|LOC_394|>", "<|LOC_395|>", "<|LOC_396|>", "<|LOC_397|>", "<|LOC_398|>", "<|LOC_399|>", "<|LOC_400|>", "<|LOC_401|>", "<|LOC_402|>", "<|LOC_403|>", "<|LOC_404|>", "<|LOC_405|>", "<|LOC_406|>", "<|LOC_407|>", "<|LOC_408|>", "<|LOC_409|>", "<|LOC_410|>", "<|LOC_411|>", "<|LOC_412|>", "<|LOC_413|>", "<|LOC_414|>", "<|LOC_415|>", "<|LOC_416|>", "<|LOC_417|>", "<|LOC_418|>", "<|LOC_419|>", "<|LOC_420|>", "<|LOC_421|>", "<|LOC_422|>", "<|LOC_423|>", "<|LOC_424|>", "<|LOC_425|>", "<|LOC_426|>", "<|LOC_427|>", "<|LOC_428|>", "<|LOC_429|>", "<|LOC_430|>", "<|LOC_431|>", "<|LOC_432|>", "<|LOC_433|>", "<|LOC_434|>", "<|LOC_435|>", "<|LOC_436|>", "<|LOC_437|>", "<|LOC_438|>", "<|LOC_439|>", "<|LOC_440|>", "<|LOC_441|>", "<|LOC_442|>", "<|LOC_443|>", "<|LOC_444|>", "<|LOC_445|>", "<|LOC_446|>", "<|LOC_447|>", "<|LOC_448|>", "<|LOC_449|>", "<|LOC_450|>", "<|LOC_451|>", "<|LOC_452|>", "<|LOC_453|>", "<|LOC_454|>", 
"<|LOC_455|>", "<|LOC_456|>", "<|LOC_457|>", "<|LOC_458|>", "<|LOC_459|>", "<|LOC_460|>", "<|LOC_461|>", "<|LOC_462|>", "<|LOC_463|>", "<|LOC_464|>", "<|LOC_465|>", "<|LOC_466|>", "<|LOC_467|>", "<|LOC_468|>", "<|LOC_469|>", "<|LOC_470|>", "<|LOC_471|>", "<|LOC_472|>", "<|LOC_473|>", "<|LOC_474|>", "<|LOC_475|>", "<|LOC_476|>", "<|LOC_477|>", "<|LOC_478|>", "<|LOC_479|>", "<|LOC_480|>", "<|LOC_481|>", "<|LOC_482|>", "<|LOC_483|>", "<|LOC_484|>", "<|LOC_485|>", "<|LOC_486|>", "<|LOC_487|>", "<|LOC_488|>", "<|LOC_489|>", "<|LOC_490|>", "<|LOC_491|>", "<|LOC_492|>", "<|LOC_493|>", "<|LOC_494|>", "<|LOC_495|>", "<|LOC_496|>", "<|LOC_497|>", "<|LOC_498|>", "<|LOC_499|>", "<|LOC_500|>", "<|LOC_501|>", "<|LOC_502|>", "<|LOC_503|>", "<|LOC_504|>", "<|LOC_505|>", "<|LOC_506|>", "<|LOC_507|>", "<|LOC_508|>", "<|LOC_509|>", "<|LOC_510|>", "<|LOC_511|>", "<|LOC_512|>", "<|LOC_513|>", "<|LOC_514|>", "<|LOC_515|>", "<|LOC_516|>", "<|LOC_517|>", "<|LOC_518|>", "<|LOC_519|>", "<|LOC_520|>", "<|LOC_521|>", "<|LOC_522|>", "<|LOC_523|>", "<|LOC_524|>", "<|LOC_525|>", "<|LOC_526|>", "<|LOC_527|>", "<|LOC_528|>", "<|LOC_529|>", "<|LOC_530|>", "<|LOC_531|>", "<|LOC_532|>", "<|LOC_533|>", "<|LOC_534|>", "<|LOC_535|>", "<|LOC_536|>", "<|LOC_537|>", "<|LOC_538|>", "<|LOC_539|>", "<|LOC_540|>", "<|LOC_541|>", "<|LOC_542|>", "<|LOC_543|>", "<|LOC_544|>", "<|LOC_545|>", "<|LOC_546|>", "<|LOC_547|>", "<|LOC_548|>", "<|LOC_549|>", "<|LOC_550|>", "<|LOC_551|>", "<|LOC_552|>", "<|LOC_553|>", "<|LOC_554|>", "<|LOC_555|>", "<|LOC_556|>", "<|LOC_557|>", "<|LOC_558|>", "<|LOC_559|>", "<|LOC_560|>", "<|LOC_561|>", "<|LOC_562|>", "<|LOC_563|>", "<|LOC_564|>", "<|LOC_565|>", "<|LOC_566|>", "<|LOC_567|>", "<|LOC_568|>", "<|LOC_569|>", "<|LOC_570|>", "<|LOC_571|>", "<|LOC_572|>", "<|LOC_573|>", "<|LOC_574|>", "<|LOC_575|>", "<|LOC_576|>", "<|LOC_577|>", "<|LOC_578|>", "<|LOC_579|>", "<|LOC_580|>", "<|LOC_581|>", "<|LOC_582|>", "<|LOC_583|>", "<|LOC_584|>", "<|LOC_585|>", "<|LOC_586|>", "<|LOC_587|>", "<|LOC_588|>", "<|LOC_589|>", "<|LOC_590|>", "<|LOC_591|>", "<|LOC_592|>", "<|LOC_593|>", "<|LOC_594|>", "<|LOC_595|>", "<|LOC_596|>", "<|LOC_597|>", "<|LOC_598|>", "<|LOC_599|>", "<|LOC_600|>", "<|LOC_601|>", "<|LOC_602|>", "<|LOC_603|>", "<|LOC_604|>", "<|LOC_605|>", "<|LOC_606|>", "<|LOC_607|>", "<|LOC_608|>", "<|LOC_609|>", "<|LOC_610|>", "<|LOC_611|>", "<|LOC_612|>", "<|LOC_613|>", "<|LOC_614|>", "<|LOC_615|>", "<|LOC_616|>", "<|LOC_617|>", "<|LOC_618|>", "<|LOC_619|>", "<|LOC_620|>", "<|LOC_621|>", "<|LOC_622|>", "<|LOC_623|>", "<|LOC_624|>", "<|LOC_625|>", "<|LOC_626|>", "<|LOC_627|>", "<|LOC_628|>", "<|LOC_629|>", "<|LOC_630|>", "<|LOC_631|>", "<|LOC_632|>", "<|LOC_633|>", "<|LOC_634|>", "<|LOC_635|>", "<|LOC_636|>", "<|LOC_637|>", "<|LOC_638|>", "<|LOC_639|>", "<|LOC_640|>", "<|LOC_641|>", "<|LOC_642|>", "<|LOC_643|>", "<|LOC_644|>", "<|LOC_645|>", "<|LOC_646|>", "<|LOC_647|>", "<|LOC_648|>", "<|LOC_649|>", "<|LOC_650|>", "<|LOC_651|>", "<|LOC_652|>", "<|LOC_653|>", "<|LOC_654|>", "<|LOC_655|>", "<|LOC_656|>", "<|LOC_657|>", "<|LOC_658|>", "<|LOC_659|>", "<|LOC_660|>", "<|LOC_661|>", "<|LOC_662|>", "<|LOC_663|>", "<|LOC_664|>", "<|LOC_665|>", "<|LOC_666|>", "<|LOC_667|>", "<|LOC_668|>", "<|LOC_669|>", "<|LOC_670|>", "<|LOC_671|>", "<|LOC_672|>", "<|LOC_673|>", "<|LOC_674|>", "<|LOC_675|>", "<|LOC_676|>", "<|LOC_677|>", "<|LOC_678|>", "<|LOC_679|>", "<|LOC_680|>", "<|LOC_681|>", "<|LOC_682|>", "<|LOC_683|>", "<|LOC_684|>", "<|LOC_685|>", "<|LOC_686|>", "<|LOC_687|>", "<|LOC_688|>", "<|LOC_689|>", "<|LOC_690|>", "<|LOC_691|>", 
"<|LOC_692|>", "<|LOC_693|>", "<|LOC_694|>", "<|LOC_695|>", "<|LOC_696|>", "<|LOC_697|>", "<|LOC_698|>", "<|LOC_699|>", "<|LOC_700|>", "<|LOC_701|>", "<|LOC_702|>", "<|LOC_703|>", "<|LOC_704|>", "<|LOC_705|>", "<|LOC_706|>", "<|LOC_707|>", "<|LOC_708|>", "<|LOC_709|>", "<|LOC_710|>", "<|LOC_711|>", "<|LOC_712|>", "<|LOC_713|>", "<|LOC_714|>", "<|LOC_715|>", "<|LOC_716|>", "<|LOC_717|>", "<|LOC_718|>", "<|LOC_719|>", "<|LOC_720|>", "<|LOC_721|>", "<|LOC_722|>", "<|LOC_723|>", "<|LOC_724|>", "<|LOC_725|>", "<|LOC_726|>", "<|LOC_727|>", "<|LOC_728|>", "<|LOC_729|>", "<|LOC_730|>", "<|LOC_731|>", "<|LOC_732|>", "<|LOC_733|>", "<|LOC_734|>", "<|LOC_735|>", "<|LOC_736|>", "<|LOC_737|>", "<|LOC_738|>", "<|LOC_739|>", "<|LOC_740|>", "<|LOC_741|>", "<|LOC_742|>", "<|LOC_743|>", "<|LOC_744|>", "<|LOC_745|>", "<|LOC_746|>", "<|LOC_747|>", "<|LOC_748|>", "<|LOC_749|>", "<|LOC_750|>", "<|LOC_751|>", "<|LOC_752|>", "<|LOC_753|>", "<|LOC_754|>", "<|LOC_755|>", "<|LOC_756|>", "<|LOC_757|>", "<|LOC_758|>", "<|LOC_759|>", "<|LOC_760|>", "<|LOC_761|>", "<|LOC_762|>", "<|LOC_763|>", "<|LOC_764|>", "<|LOC_765|>", "<|LOC_766|>", "<|LOC_767|>", "<|LOC_768|>", "<|LOC_769|>", "<|LOC_770|>", "<|LOC_771|>", "<|LOC_772|>", "<|LOC_773|>", "<|LOC_774|>", "<|LOC_775|>", "<|LOC_776|>", "<|LOC_777|>", "<|LOC_778|>", "<|LOC_779|>", "<|LOC_780|>", "<|LOC_781|>", "<|LOC_782|>", "<|LOC_783|>", "<|LOC_784|>", "<|LOC_785|>", "<|LOC_786|>", "<|LOC_787|>", "<|LOC_788|>", "<|LOC_789|>", "<|LOC_790|>", "<|LOC_791|>", "<|LOC_792|>", "<|LOC_793|>", "<|LOC_794|>", "<|LOC_795|>", "<|LOC_796|>", "<|LOC_797|>", "<|LOC_798|>", "<|LOC_799|>", "<|LOC_800|>", "<|LOC_801|>", "<|LOC_802|>", "<|LOC_803|>", "<|LOC_804|>", "<|LOC_805|>", "<|LOC_806|>", "<|LOC_807|>", "<|LOC_808|>", "<|LOC_809|>", "<|LOC_810|>", "<|LOC_811|>", "<|LOC_812|>", "<|LOC_813|>", "<|LOC_814|>", "<|LOC_815|>", "<|LOC_816|>", "<|LOC_817|>", "<|LOC_818|>", "<|LOC_819|>", "<|LOC_820|>", "<|LOC_821|>", "<|LOC_822|>", "<|LOC_823|>", "<|LOC_824|>", "<|LOC_825|>", "<|LOC_826|>", "<|LOC_827|>", "<|LOC_828|>", "<|LOC_829|>", "<|LOC_830|>", "<|LOC_831|>", "<|LOC_832|>", "<|LOC_833|>", "<|LOC_834|>", "<|LOC_835|>", "<|LOC_836|>", "<|LOC_837|>", "<|LOC_838|>", "<|LOC_839|>", "<|LOC_840|>", "<|LOC_841|>", "<|LOC_842|>", "<|LOC_843|>", "<|LOC_844|>", "<|LOC_845|>", "<|LOC_846|>", "<|LOC_847|>", "<|LOC_848|>", "<|LOC_849|>", "<|LOC_850|>", "<|LOC_851|>", "<|LOC_852|>", "<|LOC_853|>", "<|LOC_854|>", "<|LOC_855|>", "<|LOC_856|>", "<|LOC_857|>", "<|LOC_858|>", "<|LOC_859|>", "<|LOC_860|>", "<|LOC_861|>", "<|LOC_862|>", "<|LOC_863|>", "<|LOC_864|>", "<|LOC_865|>", "<|LOC_866|>", "<|LOC_867|>", "<|LOC_868|>", "<|LOC_869|>", "<|LOC_870|>", "<|LOC_871|>", "<|LOC_872|>", "<|LOC_873|>", "<|LOC_874|>", "<|LOC_875|>", "<|LOC_876|>", "<|LOC_877|>", "<|LOC_878|>", "<|LOC_879|>", "<|LOC_880|>", "<|LOC_881|>", "<|LOC_882|>", "<|LOC_883|>", "<|LOC_884|>", "<|LOC_885|>", "<|LOC_886|>", "<|LOC_887|>", "<|LOC_888|>", "<|LOC_889|>", "<|LOC_890|>", "<|LOC_891|>", "<|LOC_892|>", "<|LOC_893|>", "<|LOC_894|>", "<|LOC_895|>", "<|LOC_896|>", "<|LOC_897|>", "<|LOC_898|>", "<|LOC_899|>", "<|LOC_900|>", "<|LOC_901|>", "<|LOC_902|>", "<|LOC_903|>", "<|LOC_904|>", "<|LOC_905|>", "<|LOC_906|>", "<|LOC_907|>", "<|LOC_908|>", "<|LOC_909|>", "<|LOC_910|>", "<|LOC_911|>", "<|LOC_912|>", "<|LOC_913|>", "<|LOC_914|>", "<|LOC_915|>", "<|LOC_916|>", "<|LOC_917|>", "<|LOC_918|>", "<|LOC_919|>", "<|LOC_920|>", "<|LOC_921|>", "<|LOC_922|>", "<|LOC_923|>", "<|LOC_924|>", "<|LOC_925|>", "<|LOC_926|>", "<|LOC_927|>", "<|LOC_928|>", 
"<|LOC_929|>", "<|LOC_930|>", "<|LOC_931|>", "<|LOC_932|>", "<|LOC_933|>", "<|LOC_934|>", "<|LOC_935|>", "<|LOC_936|>", "<|LOC_937|>", "<|LOC_938|>", "<|LOC_939|>", "<|LOC_940|>", "<|LOC_941|>", "<|LOC_942|>", "<|LOC_943|>", "<|LOC_944|>", "<|LOC_945|>", "<|LOC_946|>", "<|LOC_947|>", "<|LOC_948|>", "<|LOC_949|>", "<|LOC_950|>", "<|LOC_951|>", "<|LOC_952|>", "<|LOC_953|>", "<|LOC_954|>", "<|LOC_955|>", "<|LOC_956|>", "<|LOC_957|>", "<|LOC_958|>", "<|LOC_959|>", "<|LOC_960|>", "<|LOC_961|>", "<|LOC_962|>", "<|LOC_963|>", "<|LOC_964|>", "<|LOC_965|>", "<|LOC_966|>", "<|LOC_967|>", "<|LOC_968|>", "<|LOC_969|>", "<|LOC_970|>", "<|LOC_971|>", "<|LOC_972|>", "<|LOC_973|>", "<|LOC_974|>", "<|LOC_975|>", "<|LOC_976|>", "<|LOC_977|>", "<|LOC_978|>", "<|LOC_979|>", "<|LOC_980|>", "<|LOC_981|>", "<|LOC_982|>", "<|LOC_983|>", "<|LOC_984|>", "<|LOC_985|>", "<|LOC_986|>", "<|LOC_987|>", "<|LOC_988|>", "<|LOC_989|>", "<|LOC_990|>", "<|LOC_991|>", "<|LOC_992|>", "<|LOC_993|>", "<|LOC_994|>", "<|LOC_995|>", "<|LOC_996|>", "<|LOC_997|>", "<|LOC_998|>", "<|LOC_999|>", "<|LOC_1000|>", "<|LOC_BEGIN|>", "<|LOC_END|>", "<|LOC_SEP|>", "<|CROP_COL_SEP|>", "<|CROP_ROW_SEP|>", "<|IMAGE_SEP|>"]}
tokenization_ernie4_5.py
ADDED
@@ -0,0 +1,374 @@
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union
import torch
import numpy as np
import sentencepiece as spm

from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import (
    PaddingStrategy,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Ernie4_5_Tokenizer(PreTrainedTokenizer):

    vocab_files_names = {
        "vocab_file": "tokenizer.model",
    }
    # Model input names expected by the tokenizer
    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
    # Padding side (where to add padding tokens)
    padding_side = "right"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        cls_token="<cls>",
        eos_token="</s>",
        mask_token="<mask:0>",
        pad_token="<pad>",
        sep_token="<sep>",
        unk_token="<unk>",
        additional_special_tokens=None,
        split_special_tokens=False,
        tokenizer_alpha=None,
        **kwargs,
    ):
        """
        Initialize the ERNIE tokenizer.

        Args:
            vocab_file (str): Path to the SentencePiece model file.
            bos_token (str, optional): Beginning-of-sentence token. Defaults to "<s>".
            cls_token (str, optional): Classification token. Defaults to "<cls>".
            eos_token (str, optional): End-of-sentence token. Defaults to "</s>".
            mask_token (str, optional): Mask token. Defaults to "<mask:0>".
            pad_token (str, optional): Padding token. Defaults to "<pad>".
            sep_token (str, optional): Separator token. Defaults to "<sep>".
            unk_token (str, optional): Unknown token. Defaults to "<unk>".
            additional_special_tokens (List[str], optional): Additional special tokens.
                Defaults to ["<mask:1>", "<mask:7>"].
            split_special_tokens (bool, optional): Whether to split special tokens. Defaults to False.
            tokenizer_alpha (float, optional): Alpha parameter for SentencePiece subword-regularization sampling.
            **kwargs: Additional keyword arguments passed to the parent class.
        """

        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        self.pad_id = self._convert_token_to_id(pad_token)
        self.tokenizer_alpha = tokenizer_alpha

        if additional_special_tokens is None:
            additional_special_tokens = ["<mask:1>", "<mask:7>"]
        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            eos_token=eos_token,
            mask_token=mask_token,
            pad_token=pad_token,
            sep_token=sep_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary.

        Returns:
            int: The number of tokens in the vocabulary.
        """
        return self.sp_model.vocab_size()

    def get_vocab(self):
        """Get the vocabulary as a dictionary mapping tokens to their IDs.

        Returns:
            dict: A dictionary mapping tokens to their corresponding IDs.
        """
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Tokenize text using SentencePiece.

        Args:
            text (str): The text to tokenize.

        Returns:
            list: A list of tokens.
        """
        if self.tokenizer_alpha is not None:
            return self.sp_model.encode_as_pieces(
                text,
                enable_sampling=True,
                nbest_size=-1,
                alpha=self.tokenizer_alpha,
            )
        else:
            return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an ID using the vocabulary.

        Args:
            token (str): The token to convert.

        Returns:
            int: The corresponding token ID.
        """
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Convert an ID to a token (str) using the vocabulary.

        Args:
            id (int): The token ID to convert.

        Returns:
            str: The corresponding token.
        """
        if id >= self.vocab_size:
            return self.unk_token
        else:
            return self.sp_model.id_to_piece(id)

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens back to a single string.

        Args:
            tokens (List[str]): A list of tokens to convert.

        Returns:
            str: The reconstructed string.
        """
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using the sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Build model inputs by adding special tokens to sequences.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (List[int], optional): List of token IDs for the second sequence.

        Returns:
            List[int]: List of token IDs with special tokens added.
        """
        output = token_ids_0
        last_cls_index = -1
        last_sep_index = -1
        if self.cls_token_id in output:
            last_cls_index = len(output) - output[::-1].index(self.cls_token_id) - 1
        if self.sep_token_id in output:
            last_sep_index = len(output) - output[::-1].index(self.sep_token_id) - 1

        if last_cls_index > last_sep_index:
            next_token_id = self.sep_token_id
        elif last_sep_index > last_cls_index:
            next_token_id = self.cls_token_id
        else:
            output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
            next_token_id = self.cls_token_id

        output = [self.bos_token_id] + output
        # Assume no markup in the text if token_ids_1 is given.
        if token_ids_1 is not None:
            output = output + token_ids_1 + [next_token_id]
        return output

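    # Editor's note (not part of the upstream file): for inputs without
    # existing <cls>/<sep> markup, `build_inputs_with_special_tokens` yields
    #     [bos] [cls] tokens_0 [sep]                     (token_ids_1 is None)
    #     [bos] [cls] tokens_0 [sep] tokens_1 [cls]      (with token_ids_1)
    # which is exactly the layout assumed by `get_special_tokens_mask` below.
    # If the input already contains cls/sep markup, only the counterpart of
    # whichever marker came last is appended after token_ids_1.
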
    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        """Get a mask showing which tokens are special tokens.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (List[int], optional): List of token IDs for the second sequence.
            already_has_special_tokens (bool): Whether the tokens already include special tokens.

        Returns:
            List[int]: A mask where 1 indicates special tokens and 0 indicates regular tokens.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )

        # [bos_token, cls_token, tokens_0, sep_token]
        if token_ids_1 is None:
            return [1, 1] + ([0] * len(token_ids_0)) + [1]
        # [bos_token, cls_token, tokens_0, sep_token, tokens_1, cls_token]
        return [1, 1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (str): The directory in which to save the vocabulary.
            filename_prefix (Optional[str]): Optional prefix for the saved filename.

        Returns:
            Tuple[str]: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + self.vocab_files_names["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def _pad(
        self,
        encoded_inputs: Union[Dict],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs according to the specified strategy.

        Args:
            encoded_inputs (Union[Dict]): Dictionary of encoded inputs.
            max_length (Optional[int]): Maximum length to pad to.
            padding_strategy (PaddingStrategy): Strategy for padding.
            pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
            padding_side (Optional[str]): Side on which to add padding; falls back to `self.padding_side`.
            return_attention_mask (Optional[bool]): Whether to return an attention mask.

        Returns:
            dict: Dictionary with padded inputs and optional attention mask.

        Raises:
            ValueError: If attention_mask has an unexpected type or the padding side is invalid.
        """
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names
        if return_attention_mask:
            required_input = encoded_inputs[self.model_input_names[0]]
            if padding_strategy == PaddingStrategy.LONGEST:
                max_length = len(required_input)
            if (
                max_length is not None
                and pad_to_multiple_of is not None
                and (max_length % pad_to_multiple_of != 0)
            ):
                max_length = (
                    (max_length // pad_to_multiple_of) + 1
                ) * pad_to_multiple_of
            needs_to_be_padded = (
                padding_strategy != PaddingStrategy.DO_NOT_PAD
                and len(required_input) != max_length
            )

            if (
                "attention_mask" in encoded_inputs
                and encoded_inputs["attention_mask"] is not None
            ):
                attention_mask = encoded_inputs.pop("attention_mask")
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = attention_mask.numpy()
                elif isinstance(attention_mask, list):
                    attention_mask = np.array(attention_mask)
                elif not isinstance(attention_mask, np.ndarray):
                    raise ValueError(
                        f"Unexpected type {type(attention_mask)} of attention_mask"
                    )
            else:
                # Create a default (lower-triangular causal) attention mask if none is provided
                attention_mask = np.tril(
                    np.ones((len(required_input), len(required_input)), dtype=np.int64)
                )
                attention_mask = np.expand_dims(attention_mask, axis=0)

            if needs_to_be_padded:
                difference = max_length - len(required_input)
                if self.padding_side == "right":
                    if attention_mask.ndim == 1:
                        pad_width = [(0, difference)]
                    else:
                        pad_width = [(0, 0), (0, difference), (0, difference)]
                elif self.padding_side == "left":
                    if attention_mask.ndim == 1:
                        pad_width = [(difference, 0)]
                    else:
                        pad_width = [(0, 0), (difference, 0), (difference, 0)]
                else:
                    raise ValueError(
                        "Invalid padding side: " + str(self.padding_side)
                    )
                attention_mask = np.pad(
                    attention_mask,
                    pad_width=pad_width,
                    mode="constant",
                    constant_values=0,
                )

        encoded_inputs = super()._pad(
            encoded_inputs,
            max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=False,
        )
        if return_attention_mask:
            encoded_inputs["attention_mask"] = attention_mask.tolist()
        return encoded_inputs

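One detail worth calling out: `tokenizer_alpha` switches `_tokenize` from deterministic segmentation to SentencePiece subword-regularization sampling (`enable_sampling=True, nbest_size=-1`). A hedged sketch, assuming the repository's `tokenizer.model` is available locally (the alpha value is illustrative):

    from tokenization_ernie4_5 import Ernie4_5_Tokenizer

    tok = Ernie4_5_Tokenizer("tokenizer.model", tokenizer_alpha=0.1)
    tok.tokenize("internationalization")  # segmentation may vary run to run

    tok = Ernie4_5_Tokenizer("tokenizer.model")
    tok.tokenize("internationalization")  # deterministic best segmentation
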
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34ef7db83df785924fb83d7b887b6e822a031c56e15cff40aaf9b982988180df
size 1614363
tokenizer_config.json
ADDED
@@ -0,0 +1,22 @@
{
    "bos_token": "<s>",
    "eos_token": "</s>",
    "pad_token": "<unk>",
    "unk_token": "<unk>",
    "cls_token": "<|begin_of_sentence|>",
    "sep_token": "<|end_of_sentence|>",
    "mask_token": "<mask:1>",
    "sys_start_token": "<mask:4>",
    "sys_end_token": "<mask:5>",
    "header_start_token": "<mask:6>",
    "header_end_token": "<mask:7>",
    "additional_special_tokens": null,
    "tokenizer_class": "Ernie4_5_Tokenizer",
    "auto_map": {
        "AutoTokenizer": [
            "tokenization_ernie4_5.Ernie4_5_Tokenizer",
            null
        ]
    },
    "chat_template": "{%- if not add_generation_prompt is defined -%}\n    {%- set add_generation_prompt = true -%}\n{%- endif -%}\n{%- if not cls_token is defined -%}\n    {%- set cls_token = \"<|begin_of_sentence|>\" -%}\n{%- endif -%}\n{%- if not sep_token is defined -%}\n    {%- set sep_token = \"<|end_of_sentence|>\" -%}\n{%- endif -%}\n{{- cls_token -}}\n{%- for message in messages -%}\n    {%- if message[\"role\"] == \"user\" -%}\n        {{- \"User: \" + message[\"content\"] + \"\n\" -}}\n    {%- elif message[\"role\"] == \"assistant\" -%}\n        {{- \"Assistant: \" + message[\"content\"] + sep_token -}}\n    {%- elif message[\"role\"] == \"system\" -%}\n        {{- message[\"content\"] + \"\n\" -}}\n    {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{- \"Assistant: \" -}}\n{%- endif -%}"
}
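For reference, a sketch of what this chat template renders to via `tokenizer.apply_chat_template` (assuming the tokenizer loaded from this repository):

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"},
    ]
    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # -> '<|begin_of_sentence|>You are a helpful assistant.\nUser: Hi!\nAssistant: '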